diff --git a/examples/editor/index.html b/examples/editor/index.html
index 36526634..a3d19243 100644
--- a/examples/editor/index.html
+++ b/examples/editor/index.html
@@ -69,7 +69,7 @@
import * as THREE from "three";
import { OrbitControls } from "three/addons/controls/OrbitControls.js";
import { GUI } from "lil-gui";
- import { constructGrid, SparkControls, SparkRenderer, SplatMesh, textSplats, dyno, transcodeSpz, isMobile, isPcSogs, LN_SCALE_MIN, LN_SCALE_MAX } from "@sparkjsdev/spark";
+ import { constructGrid, SparkControls, SparkRenderer, SplatMesh, textSplats, dyno, transcodeSpz, isMobile, LN_SCALE_MIN, LN_SCALE_MAX } from "@sparkjsdev/spark";
import { getAssetFileURL } from "../js/get-asset-url.js";
const scene = new THREE.Scene();
diff --git a/examples/splat-painter/index.html b/examples/splat-painter/index.html
index 133135cb..0e56cf5c 100644
--- a/examples/splat-painter/index.html
+++ b/examples/splat-painter/index.html
@@ -62,7 +62,7 @@
RgbaArray,
readRgbaArray,
SparkControls,
- SpzWriter,
+ writeSpz,
unpackSplat,
PackedSplats,
} from "@sparkjsdev/spark";
@@ -405,7 +405,7 @@
alert("No splat mesh loaded for export.");
return;
}
-
+
try {
console.log("Starting SPZ export with painted changes...");
@@ -416,7 +416,7 @@
});
currentSplatMesh.updateGenerator();
}
-
+
const rgba = new RgbaArray();
rgba.render({
renderer,
@@ -470,7 +470,7 @@
unpacked.color.g = rgbaBytes[rgbaOffset + 1] / 255;
unpacked.color.b = rgbaBytes[rgbaOffset + 2] / 255;
unpacked.opacity = opacity;
-
+
// Push to new PackedSplats
newPackedSplats.pushSplat(
unpacked.center,
@@ -482,39 +482,13 @@
processedCount++;
}
-
+
console.log(`Processed ${processedCount} splats`);
// Now export the PackedSplats to SPZ
- console.log("Creating SPZ writer...");
- const maxSh = ioOptions.maxSh;
- const spzWriter = new SpzWriter({
- numSplats: nonZeroCount,
- shDegree: maxSh,
- fractionalBits: ioOptions.fractionalBits,
- flagAntiAlias: true,
- });
-
console.log("Writing splats to SPZ...");
- // Iterate through the new packed array
- for (let i = 0; i < nonZeroCount; i++) {
- const unpacked = unpackSplat(
- newPackedSplats.packedArray,
- i,
- newPackedSplats.splatEncoding
- );
-
- spzWriter.setCenter(i, unpacked.center.x, unpacked.center.y, unpacked.center.z);
- spzWriter.setScale(i, unpacked.scales.x, unpacked.scales.y, unpacked.scales.z);
- spzWriter.setQuat(i, unpacked.quaternion.x, unpacked.quaternion.y, unpacked.quaternion.z, unpacked.quaternion.w);
- spzWriter.setAlpha(i, unpacked.opacity);
- spzWriter.setRgb(i, unpacked.color.r, unpacked.color.g, unpacked.color.b);
- }
- const spzBytes = await spzWriter.finalize();
- if (spzWriter.clippedCount > 0) {
- console.log(`Clipped ${spzWriter.clippedCount} splats. Consider decreasing fractional-bits from ${ioOptions.fractionalBits} to reduce clipping.`);
- }
-
+ const { fileBytes: spzBytes } = writeSpz(newPackedSplats, ioOptions.maxSh, ioOptions.fractionalBits);
+
console.log("Creating download...");
const blob = new Blob([spzBytes], { type: "application/octet-stream" });
const url = URL.createObjectURL(blob);
@@ -531,7 +505,7 @@
}
},
};
-
+
const ioFolder = gui.addFolder("I/O");
ioFolder.add(ioOptions, "loadFile").name("Load Splats (SPZ/PLY)");
ioFolder.add(ioOptions, "saveToSpz").name("Save Splats (SPZ)");
diff --git a/rust/Cargo.lock b/rust/Cargo.lock
index e85117aa..625c7b67 100644
--- a/rust/Cargo.lock
+++ b/rust/Cargo.lock
@@ -1073,6 +1073,7 @@ dependencies = [
"itertools",
"js-sys",
"ordered-float",
+ "serde",
"serde-wasm-bindgen",
"serde_json",
"smallvec",
diff --git a/rust/spark-rs/Cargo.toml b/rust/spark-rs/Cargo.toml
index 7c4f3e13..bab51740 100644
--- a/rust/spark-rs/Cargo.toml
+++ b/rust/spark-rs/Cargo.toml
@@ -26,5 +26,6 @@ web-sys = { workspace = true, features = ["Window", "Performance"] }
spark-lib = { path = "../spark-lib" }
serde-wasm-bindgen.workspace = true
serde_json.workspace = true
+serde.workspace = true
itertools.workspace = true
console_error_panic_hook = "0.1"
diff --git a/rust/spark-rs/src/lib.rs b/rust/spark-rs/src/lib.rs
index 4aa1b10e..68d11d53 100644
--- a/rust/spark-rs/src/lib.rs
+++ b/rust/spark-rs/src/lib.rs
@@ -2,6 +2,8 @@
use std::cell::RefCell;
use js_sys::{Array, Float32Array, Object, Reflect, Uint8Array, Uint16Array, Uint32Array};
use spark_lib::decoder::{ChunkReceiver, MultiDecoder, SplatEncoding, SplatFileType, SplatGetter};
+use spark_lib::spz::SpzEncoder;
+use spark_lib::gsplat::{GsplatSH1,GsplatSH2,GsplatSH3};
use spark_lib::gsplat::GsplatArray as GsplatArrayInner;
use spark_lib::csplat::CsplatArray as CsplatArrayInner;
use spark_lib::tsplat::TsplatArray;
@@ -16,6 +18,9 @@ use raycast::{raycast_packed_ellipsoids, raycast_ext_ellipsoids};
mod sort;
use sort::{sort_internal, SortBuffers, sort32_internal, Sort32Buffers};
+mod transform;
+use transform::{transform_gsplatarray, TransformOptions};
+
mod decoder;
mod packed_splats;
mod ext_splats;
@@ -179,6 +184,7 @@ impl GsplatArray {
}
}
+
#[wasm_bindgen]
impl GsplatArray {
pub fn len(&self) -> usize {
@@ -257,6 +263,32 @@ impl GsplatArray {
pub fn inject_rgba8(&mut self, rgba: Uint8Array) {
self.inner.inject_rgba8(&rgba.to_vec());
}
+
+ pub fn transform(&mut self, transform: JsValue) -> Result<(), JsValue> {
+ let transform_options: TransformOptions = serde_wasm_bindgen::from_value(transform)?;
+ transform_gsplatarray(&mut self.inner, transform_options);
+ Ok(())
+ }
+
+ pub fn concat(&mut self, other: &mut GsplatArray) -> Result<(), JsValue> {
+ for i in 0..other.inner.len() {
+ let sh1 = if other.maxShDegree >= 1 { other.inner.sh1[i].clone() } else { GsplatSH1::default() };
+ let sh2 = if other.maxShDegree >= 2 { other.inner.sh2[i].clone() } else { GsplatSH2::default() };
+ let sh3 = if other.maxShDegree >= 3 { other.inner.sh3[i].clone() } else { GsplatSH3::default() };
+ self.inner.push_splat(other.inner.get(i).clone(), Some(sh1), Some(sh2), Some(sh3));
+ }
+ Ok(())
+ }
+
+ pub fn encode_to_spz(mut self, max_sh: u32, fractional_bits: u8) -> Result {
+ self.inner.clamp_sh_degree(max_sh as usize);
+ self.maxShDegree = self.inner.max_sh_degree;
+ let encoded = match SpzEncoder::new(self.inner).with_max_sh(max_sh as usize).with_fractional_bits(fractional_bits).encode() {
+ Err(err) => { return Err(JsValue::from(err.to_string())); },
+ Ok(encoded) => encoded
+ };
+ Ok(Uint8Array::from(encoded.as_slice()))
+ }
}
#[wasm_bindgen]
diff --git a/rust/spark-rs/src/transform.rs b/rust/spark-rs/src/transform.rs
new file mode 100644
index 00000000..773c649a
--- /dev/null
+++ b/rust/spark-rs/src/transform.rs
@@ -0,0 +1,73 @@
+use glam::{Vec3A, Quat};
+
+use spark_lib::gsplat::GsplatArray;
+use spark_lib::tsplat::TsplatArray;
+use spark_lib::tsplat::Tsplat;
+use serde::{Deserialize, Serialize};
+
+use spark_lib::decoder::SplatReceiver;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TransformOptions {
+ pub translation: [f32; 3],
+ pub rotation: [f32; 4],
+ pub scale: f32,
+ pub clip: Option<[f32; 6]>,
+ #[serde(rename = "opacityThreshold")]
+ pub opacity_threshold: f32,
+}
+
+pub fn transform_gsplatarray(gsplats: &mut GsplatArray, transform_options: TransformOptions) {
+ let translation = Vec3A::from_array(transform_options.translation);
+ let quaternion = Quat::from_array(transform_options.rotation);
+ let scale = Vec3A::splat(transform_options.scale);
+
+ let clip = transform_options.clip.map(|clip| (Vec3A::from_slice(&clip[..3]), Vec3A::from_slice(&clip[3..])));
+
+ let mut out_index = 0;
+ for splat_index in 0..gsplats.splats.len() {
+ let in_splat = gsplats.get(splat_index);
+
+ let mut center = in_splat.center();
+ // Transform center
+ center = quaternion * (center * scale) + translation;
+
+ // Check clip box
+ let clipped = match clip {
+ Some((min, max)) => (center.cmplt(min)).any() || (center.cmpgt(max)).any(),
+ None => false
+ };
+ if clipped {
+ continue;
+ }
+
+ // Check opacity threshold
+ let opacity = in_splat.opacity();
+ if opacity < transform_options.opacity_threshold {
+ continue;
+ }
+
+ let mut scales = in_splat.scales();
+ let mut quat = in_splat.quaternion();
+ let rgb = in_splat.rgb();
+
+ gsplats.set_center(out_index, 1, ¢er.to_array());
+
+ scales *= scale;
+ gsplats.set_scale(out_index, 1, &scales.to_array());
+
+ quat *= quaternion;
+ gsplats.set_quat(out_index, 1, &quat.to_array());
+
+ gsplats.set_rgb(out_index, 1, &rgb.to_array());
+ gsplats.set_opacity(out_index, 1, &[opacity]);
+
+ gsplats.set_sh1(out_index, 1, gsplats.get_sh1(splat_index).as_slice());
+ gsplats.set_sh2(out_index, 1, gsplats.get_sh2(splat_index).as_slice());
+ gsplats.set_sh3(out_index, 1, gsplats.get_sh3(splat_index).as_slice());
+
+ out_index += 1;
+ }
+
+ gsplats.truncate(out_index);
+}
\ No newline at end of file
diff --git a/src/SplatLoader.ts b/src/SplatLoader.ts
index b5866099..c2a509e2 100644
--- a/src/SplatLoader.ts
+++ b/src/SplatLoader.ts
@@ -5,7 +5,6 @@ import { PackedSplats, type PackedSplatsOptions } from "./PackedSplats";
import { SplatMesh } from "./SplatMesh";
import { workerPool } from "./SplatWorker";
import { type SplatEncoding, SplatFileType } from "./defines";
-import { PlyReader } from "./ply";
import { decompressPartialGzip, getTextureSize } from "./utils";
// SplatLoader implements the THREE.Loader interface and supports loading a variety
@@ -340,65 +339,6 @@ export class SplatLoader extends Loader {
}
}
-async function fetchWithProgress(
- request: Request,
- onProgress?: (event: ProgressEvent) => void,
-) {
- const response = await fetch(request);
- if (!response.ok) {
- throw new Error(
- `${response.status} "${response.statusText}" fetching URL: ${request.url}`,
- );
- }
- if (!response.body) {
- throw new Error(`Response body is null for URL: ${request.url}`);
- }
-
- const reader = response.body.getReader();
- let loaded = 0;
- const chunks: Uint8Array[] = [];
- try {
- const contentLength = Number.parseInt(
- response.headers.get("Content-Length") || "0",
- );
- const total = Number.isNaN(contentLength) ? 0 : contentLength;
-
- while (true) {
- const { done, value } = await reader.read();
- if (done) {
- break;
- }
- chunks.push(value);
- loaded += value.length;
-
- if (onProgress) {
- onProgress(
- new ProgressEvent("progress", {
- lengthComputable: total !== 0,
- loaded,
- total,
- }),
- );
- }
- }
- } catch (err) {
- try {
- const reason = err instanceof Error ? err.message : "Unknown error";
- await reader.cancel(reason);
- } catch {}
- throw err;
- }
-
- // Combine chunks into a single buffer
- const bytes = new Uint8Array(loaded);
- let offset = 0;
- for (const chunk of chunks) {
- bytes.set(chunk, offset);
- offset += chunk.length;
- }
- return bytes.buffer;
-}
-
export function getSplatFileType(
fileBytes: Uint8Array,
): SplatFileType | undefined {
diff --git a/src/antisplat.ts b/src/antisplat.ts
deleted file mode 100644
index 8b52864d..00000000
--- a/src/antisplat.ts
+++ /dev/null
@@ -1,125 +0,0 @@
-import type { SplatEncoding } from "./defines";
-import { computeMaxSplats, setPackedSplat } from "./utils";
-
-export function decodeAntiSplat(
- fileBytes: Uint8Array,
- initNumSplats: (numSplats: number) => void,
- splatCallback: (
- index: number,
- x: number,
- y: number,
- z: number,
- scaleX: number,
- scaleY: number,
- scaleZ: number,
- quatX: number,
- quatY: number,
- quatZ: number,
- quatW: number,
- opacity: number,
- r: number,
- g: number,
- b: number,
- ) => void,
-) {
- const numSplats = Math.floor(fileBytes.length / 32); // 32 bytes per splat
- if (numSplats * 32 !== fileBytes.length) {
- throw new Error("Invalid .splat file size");
- }
- initNumSplats(numSplats);
-
- const f32 = new Float32Array(fileBytes.buffer);
- for (let i = 0; i < numSplats; ++i) {
- const i32 = i * 32;
- const i8 = i * 8;
- const x = f32[i8 + 0];
- const y = f32[i8 + 1];
- const z = f32[i8 + 2];
- const scaleX = f32[i8 + 3];
- const scaleY = f32[i8 + 4];
- const scaleZ = f32[i8 + 5];
- const r = fileBytes[i32 + 24] / 255;
- const g = fileBytes[i32 + 25] / 255;
- const b = fileBytes[i32 + 26] / 255;
- const opacity = fileBytes[i32 + 27] / 255;
- const quatW = (fileBytes[i32 + 28] - 128) / 128;
- const quatX = (fileBytes[i32 + 29] - 128) / 128;
- const quatY = (fileBytes[i32 + 30] - 128) / 128;
- const quatZ = (fileBytes[i32 + 31] - 128) / 128;
- splatCallback(
- i,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- );
- }
-}
-
-export function unpackAntiSplat(
- fileBytes: Uint8Array,
- splatEncoding: SplatEncoding,
-): {
- packedArray: Uint32Array;
- numSplats: number;
-} {
- let numSplats = 0;
- let maxSplats = 0;
- let packedArray = new Uint32Array(0);
- decodeAntiSplat(
- fileBytes,
- (cbNumSplats) => {
- numSplats = cbNumSplats;
- maxSplats = computeMaxSplats(numSplats);
- packedArray = new Uint32Array(maxSplats * 4);
- },
- (
- index,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- ) => {
- setPackedSplat(
- packedArray,
- index,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- splatEncoding,
- );
- },
- );
- return { packedArray, numSplats };
-}
diff --git a/src/index.ts b/src/index.ts
index 4a68545e..2ce39cc9 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -11,10 +11,8 @@ export { RgbaArray, readRgbaArray } from "./RgbaArray";
export {
SplatLoader,
getSplatFileType,
- isPcSogs,
} from "./SplatLoader";
-export { PlyReader } from "./ply";
-export { SpzReader, SpzWriter, transcodeSpz } from "./spz";
+export { transcodeSpz, writeSpz } from "./spz";
export { PackedSplats, type PackedSplatsOptions } from "./PackedSplats";
export { ExtSplats, type ExtSplatsOptions } from "./ExtSplats";
diff --git a/src/ksplat.ts b/src/ksplat.ts
deleted file mode 100644
index 3c46e58e..00000000
--- a/src/ksplat.ts
+++ /dev/null
@@ -1,636 +0,0 @@
-import type { SplatEncoding } from "./defines";
-import {
- computeMaxSplats,
- encodeSh1Rgb,
- encodeSh2Rgb,
- encodeSh3Rgb,
- fromHalf,
- setPackedSplat,
-} from "./utils";
-
-type KsplatCompression = {
- bytesPerCenter: number;
- bytesPerScale: number;
- bytesPerRotation: number;
- bytesPerColor: number;
- bytesPerSphericalHarmonicsComponent: number;
- scaleOffsetBytes: number;
- rotationOffsetBytes: number;
- colorOffsetBytes: number;
- sphericalHarmonicsOffsetBytes: number;
- scaleRange: number;
-};
-
-const KSPLAT_COMPRESSION: Record = {
- 0: {
- bytesPerCenter: 12,
- bytesPerScale: 12,
- bytesPerRotation: 16,
- bytesPerColor: 4,
- bytesPerSphericalHarmonicsComponent: 4,
- scaleOffsetBytes: 12,
- rotationOffsetBytes: 24,
- colorOffsetBytes: 40,
- sphericalHarmonicsOffsetBytes: 44,
- scaleRange: 1,
- },
- 1: {
- bytesPerCenter: 6,
- bytesPerScale: 6,
- bytesPerRotation: 8,
- bytesPerColor: 4,
- bytesPerSphericalHarmonicsComponent: 2,
- scaleOffsetBytes: 6,
- rotationOffsetBytes: 12,
- colorOffsetBytes: 20,
- sphericalHarmonicsOffsetBytes: 24,
- scaleRange: 32767,
- },
- 2: {
- bytesPerCenter: 6,
- bytesPerScale: 6,
- bytesPerRotation: 8,
- bytesPerColor: 4,
- bytesPerSphericalHarmonicsComponent: 1,
- scaleOffsetBytes: 6,
- rotationOffsetBytes: 12,
- colorOffsetBytes: 20,
- sphericalHarmonicsOffsetBytes: 24,
- scaleRange: 32767,
- },
-};
-
-const KSPLAT_SH_DEGREE_TO_COMPONENTS: Record = {
- 0: 0,
- 1: 9,
- 2: 24,
- 3: 45,
-};
-
-export function decodeKsplat(
- fileBytes: Uint8Array,
- initNumSplats: (numSplats: number) => void,
- splatCallback: (
- index: number,
- x: number,
- y: number,
- z: number,
- scaleX: number,
- scaleY: number,
- scaleZ: number,
- quatX: number,
- quatY: number,
- quatZ: number,
- quatW: number,
- opacity: number,
- r: number,
- g: number,
- b: number,
- ) => void,
- shCallback?: (
- index: number,
- sh1: Float32Array,
- sh2?: Float32Array,
- sh3?: Float32Array,
- ) => void,
-) {
- const HEADER_BYTES = 4096;
- const SECTION_BYTES = 1024;
-
- let headerOffset = 0;
- const header = new DataView(fileBytes.buffer, headerOffset, HEADER_BYTES);
- headerOffset += HEADER_BYTES;
-
- const versionMajor = header.getUint8(0);
- const versionMinor = header.getUint8(1);
- if (versionMajor !== 0 || versionMinor < 1) {
- throw new Error(
- `Unsupported .ksplat version: ${versionMajor}.${versionMinor}`,
- );
- }
- const maxSectionCount = header.getUint32(4, true);
- // const sectionCount = header.getUint32(8, true);
- // const maxSplatCount = header.getUint32(12, true);
- const splatCount = header.getUint32(16, true);
- const compressionLevel = header.getUint16(20, true);
- if (compressionLevel < 0 || compressionLevel > 2) {
- throw new Error(`Invalid .ksplat compression level: ${compressionLevel}`);
- }
- // const sceneCenterX = header.getFloat32(24, true);
- // const sceneCenterY = header.getFloat32(28, true);
- // const sceneCenterZ = header.getFloat32(32, true);
- const minSphericalHarmonicsCoeff = header.getFloat32(36, true) || -1.5;
- const maxSphericalHarmonicsCoeff = header.getFloat32(40, true) || 1.5;
-
- const numSplats = splatCount;
- initNumSplats(numSplats);
- const maxSplats = computeMaxSplats(numSplats);
- const packedArray = new Uint32Array(maxSplats * 4);
- const extra: Record = {};
-
- let sectionBase = HEADER_BYTES + maxSectionCount * SECTION_BYTES;
-
- for (let section = 0; section < maxSectionCount; ++section) {
- const section = new DataView(fileBytes.buffer, headerOffset, SECTION_BYTES);
- headerOffset += SECTION_BYTES;
-
- const sectionSplatCount = section.getUint32(0, true);
- const sectionMaxSplatCount = section.getUint32(4, true);
- const bucketSize = section.getUint32(8, true);
- const bucketCount = section.getUint32(12, true);
- const bucketBlockSize = section.getFloat32(16, true);
- const bucketStorageSizeBytes = section.getUint16(20, true);
- const compressionScaleRange =
- (section.getUint32(24, true) ||
- KSPLAT_COMPRESSION[compressionLevel]?.scaleRange) ??
- 1;
- const fullBucketCount = section.getUint32(32, true);
- const fullBucketSplats = fullBucketCount * bucketSize;
- const partiallyFilledBucketCount = section.getUint32(36, true);
- const bucketsMetaDataSizeBytes = partiallyFilledBucketCount * 4;
- const bucketsStorageSizeBytes =
- bucketStorageSizeBytes * bucketCount + bucketsMetaDataSizeBytes;
- const sphericalHarmonicsDegree = section.getUint16(40, true);
- const shComponents =
- KSPLAT_SH_DEGREE_TO_COMPONENTS[sphericalHarmonicsDegree];
-
- const {
- bytesPerCenter,
- bytesPerScale,
- bytesPerRotation,
- bytesPerColor,
- bytesPerSphericalHarmonicsComponent,
- scaleOffsetBytes,
- rotationOffsetBytes,
- colorOffsetBytes,
- sphericalHarmonicsOffsetBytes,
- } = KSPLAT_COMPRESSION[compressionLevel];
- const bytesPerSplat =
- bytesPerCenter +
- bytesPerScale +
- bytesPerRotation +
- bytesPerColor +
- shComponents * bytesPerSphericalHarmonicsComponent;
- const splatDataStorageSizeBytes = bytesPerSplat * sectionMaxSplatCount;
- const storageSizeBytes =
- splatDataStorageSizeBytes + bucketsStorageSizeBytes;
-
- const sh1Index = [0, 3, 6, 1, 4, 7, 2, 5, 8];
- const sh2Index = [
- 9, 14, 19, 10, 15, 20, 11, 16, 21, 12, 17, 22, 13, 18, 23,
- ];
- const sh3Index = [
- 24, 31, 38, 25, 32, 39, 26, 33, 40, 27, 34, 41, 28, 35, 42, 29, 36, 43,
- 30, 37, 44,
- ];
- const sh1 =
- sphericalHarmonicsDegree >= 1 ? new Float32Array(3 * 3) : undefined;
- const sh2 =
- sphericalHarmonicsDegree >= 2 ? new Float32Array(5 * 3) : undefined;
- const sh3 =
- sphericalHarmonicsDegree >= 3 ? new Float32Array(7 * 3) : undefined;
-
- const compressionScaleFactor = bucketBlockSize / 2 / compressionScaleRange;
- const bucketsBase = sectionBase + bucketsMetaDataSizeBytes;
- const dataBase = sectionBase + bucketsStorageSizeBytes;
- const data = new DataView(
- fileBytes.buffer,
- dataBase,
- splatDataStorageSizeBytes,
- );
- const bucketArray = new Float32Array(
- fileBytes.buffer,
- bucketsBase,
- bucketCount * 3,
- );
- const partiallyFilledBucketLengths = new Uint32Array(
- fileBytes.buffer,
- sectionBase,
- partiallyFilledBucketCount,
- );
-
- function getSh(splatOffset: number, component: number) {
- if (compressionLevel === 0) {
- return data.getFloat32(
- splatOffset + sphericalHarmonicsOffsetBytes + component * 4,
- true,
- );
- }
- if (compressionLevel === 1) {
- return fromHalf(
- data.getUint16(
- splatOffset + sphericalHarmonicsOffsetBytes + component * 2,
- true,
- ),
- );
- }
- const t =
- data.getUint8(splatOffset + sphericalHarmonicsOffsetBytes + component) /
- 255;
- return (
- minSphericalHarmonicsCoeff +
- t * (maxSphericalHarmonicsCoeff - minSphericalHarmonicsCoeff)
- );
- }
-
- let partialBucketIndex = fullBucketCount;
- let partialBucketBase = fullBucketSplats;
-
- for (let i = 0; i < sectionSplatCount; ++i) {
- const splatOffset = i * bytesPerSplat;
-
- let bucketIndex: number;
- if (i < fullBucketSplats) {
- bucketIndex = Math.floor(i / bucketSize);
- } else {
- const bucketLength =
- partiallyFilledBucketLengths[partialBucketIndex - fullBucketCount];
- if (i >= partialBucketBase + bucketLength) {
- partialBucketIndex += 1;
- partialBucketBase += bucketLength;
- }
- bucketIndex = partialBucketIndex;
- }
-
- const x =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + 0, true)
- : (data.getUint16(splatOffset + 0, true) - compressionScaleRange) *
- compressionScaleFactor +
- bucketArray[3 * bucketIndex + 0];
- const y =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + 4, true)
- : (data.getUint16(splatOffset + 2, true) - compressionScaleRange) *
- compressionScaleFactor +
- bucketArray[3 * bucketIndex + 1];
- const z =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + 8, true)
- : (data.getUint16(splatOffset + 4, true) - compressionScaleRange) *
- compressionScaleFactor +
- bucketArray[3 * bucketIndex + 2];
-
- const scaleX =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + scaleOffsetBytes + 0, true)
- : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 0, true));
- const scaleY =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + scaleOffsetBytes + 4, true)
- : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 2, true));
- const scaleZ =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + scaleOffsetBytes + 8, true)
- : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 4, true));
-
- const quatW =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + rotationOffsetBytes + 0, true)
- : fromHalf(
- data.getUint16(splatOffset + rotationOffsetBytes + 0, true),
- );
- const quatX =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + rotationOffsetBytes + 4, true)
- : fromHalf(
- data.getUint16(splatOffset + rotationOffsetBytes + 2, true),
- );
- const quatY =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + rotationOffsetBytes + 8, true)
- : fromHalf(
- data.getUint16(splatOffset + rotationOffsetBytes + 4, true),
- );
- const quatZ =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + rotationOffsetBytes + 12, true)
- : fromHalf(
- data.getUint16(splatOffset + rotationOffsetBytes + 6, true),
- );
-
- const r = data.getUint8(splatOffset + colorOffsetBytes + 0) / 255;
- const g = data.getUint8(splatOffset + colorOffsetBytes + 1) / 255;
- const b = data.getUint8(splatOffset + colorOffsetBytes + 2) / 255;
- const opacity = data.getUint8(splatOffset + colorOffsetBytes + 3) / 255;
-
- splatCallback(
- i,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- );
-
- if (sphericalHarmonicsDegree >= 1 && sh1) {
- for (const [i, key] of sh1Index.entries()) {
- sh1[i] = getSh(splatOffset, key);
- }
- if (sh2) {
- for (const [i, key] of sh2Index.entries()) {
- sh2[i] = getSh(splatOffset, key);
- }
- }
- if (sh3) {
- for (const [i, key] of sh3Index.entries()) {
- sh3[i] = getSh(splatOffset, key);
- }
- }
- shCallback?.(i, sh1, sh2, sh3);
- }
- }
- sectionBase += storageSizeBytes;
- }
-}
-
-export function unpackKsplat(
- fileBytes: Uint8Array,
- splatEncoding: SplatEncoding,
-): {
- packedArray: Uint32Array;
- numSplats: number;
- extra: Record;
-} {
- const HEADER_BYTES = 4096;
- const SECTION_BYTES = 1024;
-
- let headerOffset = 0;
- const header = new DataView(fileBytes.buffer, headerOffset, HEADER_BYTES);
- headerOffset += HEADER_BYTES;
-
- const versionMajor = header.getUint8(0);
- const versionMinor = header.getUint8(1);
- if (versionMajor !== 0 || versionMinor < 1) {
- throw new Error(
- `Unsupported .ksplat version: ${versionMajor}.${versionMinor}`,
- );
- }
- const maxSectionCount = header.getUint32(4, true);
- // const sectionCount = header.getUint32(8, true);
- // const maxSplatCount = header.getUint32(12, true);
- const splatCount = header.getUint32(16, true);
- const compressionLevel = header.getUint16(20, true);
- if (compressionLevel < 0 || compressionLevel > 2) {
- throw new Error(`Invalid .ksplat compression level: ${compressionLevel}`);
- }
- // const sceneCenterX = header.getFloat32(24, true);
- // const sceneCenterY = header.getFloat32(28, true);
- // const sceneCenterZ = header.getFloat32(32, true);
- const minSphericalHarmonicsCoeff = header.getFloat32(36, true) || -1.5;
- const maxSphericalHarmonicsCoeff = header.getFloat32(40, true) || 1.5;
-
- const numSplats = splatCount;
- const maxSplats = computeMaxSplats(numSplats);
- const packedArray = new Uint32Array(maxSplats * 4);
- const extra: Record = {};
-
- let sectionBase = HEADER_BYTES + maxSectionCount * SECTION_BYTES;
-
- for (let section = 0; section < maxSectionCount; ++section) {
- const section = new DataView(fileBytes.buffer, headerOffset, SECTION_BYTES);
- headerOffset += SECTION_BYTES;
-
- const sectionSplatCount = section.getUint32(0, true);
- const sectionMaxSplatCount = section.getUint32(4, true);
- const bucketSize = section.getUint32(8, true);
- const bucketCount = section.getUint32(12, true);
- const bucketBlockSize = section.getFloat32(16, true);
- const bucketStorageSizeBytes = section.getUint16(20, true);
- const compressionScaleRange =
- (section.getUint32(24, true) ||
- KSPLAT_COMPRESSION[compressionLevel]?.scaleRange) ??
- 1;
- const fullBucketCount = section.getUint32(32, true);
- const fullBucketSplats = fullBucketCount * bucketSize;
- const partiallyFilledBucketCount = section.getUint32(36, true);
- const bucketsMetaDataSizeBytes = partiallyFilledBucketCount * 4;
- const bucketsStorageSizeBytes =
- bucketStorageSizeBytes * bucketCount + bucketsMetaDataSizeBytes;
- const sphericalHarmonicsDegree = section.getUint16(40, true);
- const shComponents =
- KSPLAT_SH_DEGREE_TO_COMPONENTS[sphericalHarmonicsDegree];
-
- const {
- bytesPerCenter,
- bytesPerScale,
- bytesPerRotation,
- bytesPerColor,
- bytesPerSphericalHarmonicsComponent,
- scaleOffsetBytes,
- rotationOffsetBytes,
- colorOffsetBytes,
- sphericalHarmonicsOffsetBytes,
- } = KSPLAT_COMPRESSION[compressionLevel];
- const bytesPerSplat =
- bytesPerCenter +
- bytesPerScale +
- bytesPerRotation +
- bytesPerColor +
- shComponents * bytesPerSphericalHarmonicsComponent;
- const splatDataStorageSizeBytes = bytesPerSplat * sectionMaxSplatCount;
- const storageSizeBytes =
- splatDataStorageSizeBytes + bucketsStorageSizeBytes;
-
- const sh1Index = [0, 3, 6, 1, 4, 7, 2, 5, 8];
- const sh2Index = [
- 9, 14, 19, 10, 15, 20, 11, 16, 21, 12, 17, 22, 13, 18, 23,
- ];
- const sh3Index = [
- 24, 31, 38, 25, 32, 39, 26, 33, 40, 27, 34, 41, 28, 35, 42, 29, 36, 43,
- 30, 37, 44,
- ];
- const sh1 =
- sphericalHarmonicsDegree >= 1 ? new Float32Array(3 * 3) : undefined;
- const sh2 =
- sphericalHarmonicsDegree >= 2 ? new Float32Array(5 * 3) : undefined;
- const sh3 =
- sphericalHarmonicsDegree >= 3 ? new Float32Array(7 * 3) : undefined;
-
- const compressionScaleFactor = bucketBlockSize / 2 / compressionScaleRange;
- const bucketsBase = sectionBase + bucketsMetaDataSizeBytes;
- const dataBase = sectionBase + bucketsStorageSizeBytes;
- const data = new DataView(
- fileBytes.buffer,
- dataBase,
- splatDataStorageSizeBytes,
- );
- const bucketArray = new Float32Array(
- fileBytes.buffer,
- bucketsBase,
- bucketCount * 3,
- );
- const partiallyFilledBucketLengths = new Uint32Array(
- fileBytes.buffer,
- sectionBase,
- partiallyFilledBucketCount,
- );
-
- function getSh(splatOffset: number, component: number) {
- if (compressionLevel === 0) {
- return data.getFloat32(
- splatOffset + sphericalHarmonicsOffsetBytes + component * 4,
- true,
- );
- }
- if (compressionLevel === 1) {
- return fromHalf(
- data.getUint16(
- splatOffset + sphericalHarmonicsOffsetBytes + component * 2,
- true,
- ),
- );
- }
- const t =
- data.getUint8(splatOffset + sphericalHarmonicsOffsetBytes + component) /
- 255;
- return (
- minSphericalHarmonicsCoeff +
- t * (maxSphericalHarmonicsCoeff - minSphericalHarmonicsCoeff)
- );
- }
-
- let partialBucketIndex = fullBucketCount;
- let partialBucketBase = fullBucketSplats;
-
- for (let i = 0; i < sectionSplatCount; ++i) {
- const splatOffset = i * bytesPerSplat;
-
- let bucketIndex: number;
- if (i < fullBucketSplats) {
- bucketIndex = Math.floor(i / bucketSize);
- } else {
- const bucketLength =
- partiallyFilledBucketLengths[partialBucketIndex - fullBucketCount];
- if (i >= partialBucketBase + bucketLength) {
- partialBucketIndex += 1;
- partialBucketBase += bucketLength;
- }
- bucketIndex = partialBucketIndex;
- }
-
- const x =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + 0, true)
- : (data.getUint16(splatOffset + 0, true) - compressionScaleRange) *
- compressionScaleFactor +
- bucketArray[3 * bucketIndex + 0];
- const y =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + 4, true)
- : (data.getUint16(splatOffset + 2, true) - compressionScaleRange) *
- compressionScaleFactor +
- bucketArray[3 * bucketIndex + 1];
- const z =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + 8, true)
- : (data.getUint16(splatOffset + 4, true) - compressionScaleRange) *
- compressionScaleFactor +
- bucketArray[3 * bucketIndex + 2];
-
- const scaleX =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + scaleOffsetBytes + 0, true)
- : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 0, true));
- const scaleY =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + scaleOffsetBytes + 4, true)
- : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 2, true));
- const scaleZ =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + scaleOffsetBytes + 8, true)
- : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 4, true));
-
- const quatW =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + rotationOffsetBytes + 0, true)
- : fromHalf(
- data.getUint16(splatOffset + rotationOffsetBytes + 0, true),
- );
- const quatX =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + rotationOffsetBytes + 4, true)
- : fromHalf(
- data.getUint16(splatOffset + rotationOffsetBytes + 2, true),
- );
- const quatY =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + rotationOffsetBytes + 8, true)
- : fromHalf(
- data.getUint16(splatOffset + rotationOffsetBytes + 4, true),
- );
- const quatZ =
- compressionLevel === 0
- ? data.getFloat32(splatOffset + rotationOffsetBytes + 12, true)
- : fromHalf(
- data.getUint16(splatOffset + rotationOffsetBytes + 6, true),
- );
-
- const r = data.getUint8(splatOffset + colorOffsetBytes + 0) / 255;
- const g = data.getUint8(splatOffset + colorOffsetBytes + 1) / 255;
- const b = data.getUint8(splatOffset + colorOffsetBytes + 2) / 255;
- const opacity = data.getUint8(splatOffset + colorOffsetBytes + 3) / 255;
-
- setPackedSplat(
- packedArray,
- i,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- splatEncoding,
- );
-
- if (sphericalHarmonicsDegree >= 1) {
- if (sh1) {
- if (!extra.sh1) {
- extra.sh1 = new Uint32Array(numSplats * 2);
- }
- for (const [i, key] of sh1Index.entries()) {
- sh1[i] = getSh(splatOffset, key);
- }
- encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1, splatEncoding);
- }
- if (sh2) {
- if (!extra.sh2) {
- extra.sh2 = new Uint32Array(numSplats * 4);
- }
- for (const [i, key] of sh2Index.entries()) {
- sh2[i] = getSh(splatOffset, key);
- }
- encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2, splatEncoding);
- }
- if (sh3) {
- if (!extra.sh3) {
- extra.sh3 = new Uint32Array(numSplats * 4);
- }
- for (const [i, key] of sh3Index.entries()) {
- sh3[i] = getSh(splatOffset, key);
- }
- encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3, splatEncoding);
- }
- }
- }
- sectionBase += storageSizeBytes;
- }
- return { packedArray, numSplats, extra };
-}
diff --git a/src/pcsogs.ts b/src/pcsogs.ts
deleted file mode 100644
index ca883a71..00000000
--- a/src/pcsogs.ts
+++ /dev/null
@@ -1,387 +0,0 @@
-import { unzip } from "fflate";
-import {
- type PcSogsJson,
- type PcSogsV2Json,
- tryPcSogsZip,
-} from "./SplatLoader";
-import type { SplatEncoding } from "./defines";
-import {
- computeMaxSplats,
- encodeSh1Rgb,
- encodeSh2Rgb,
- encodeSh3Rgb,
- setPackedSplatCenter,
- setPackedSplatQuat,
- setPackedSplatRgba,
- setPackedSplatScales,
-} from "./utils";
-
-export async function unpackPcSogs(
- json: PcSogsJson | PcSogsV2Json,
- extraFiles: Record,
- splatEncoding: SplatEncoding,
-): Promise<{
- packedArray: Uint32Array;
- numSplats: number;
- extra: Record;
-}> {
- const isVersion2 = "version" in json;
-
- if (!isVersion2 && json.quats.encoding !== "quaternion_packed") {
- throw new Error("Unsupported quaternion encoding");
- }
-
- const numSplats = isVersion2 ? json.count : json.means.shape[0];
- const maxSplats = computeMaxSplats(numSplats);
- const packedArray = new Uint32Array(maxSplats * 4);
- const extra: Record = {};
-
- const meansPromise = Promise.all([
- decodeImageRgba(extraFiles[json.means.files[0]]),
- decodeImageRgba(extraFiles[json.means.files[1]]),
- ]).then((means) => {
- for (let i = 0; i < numSplats; ++i) {
- const i4 = i * 4;
- const fx = (means[0][i4 + 0] + (means[1][i4 + 0] << 8)) / 65535;
- const fy = (means[0][i4 + 1] + (means[1][i4 + 1] << 8)) / 65535;
- const fz = (means[0][i4 + 2] + (means[1][i4 + 2] << 8)) / 65535;
- let x =
- json.means.mins[0] + (json.means.maxs[0] - json.means.mins[0]) * fx;
- let y =
- json.means.mins[1] + (json.means.maxs[1] - json.means.mins[1]) * fy;
- let z =
- json.means.mins[2] + (json.means.maxs[2] - json.means.mins[2]) * fz;
- x = Math.sign(x) * (Math.exp(Math.abs(x)) - 1);
- y = Math.sign(y) * (Math.exp(Math.abs(y)) - 1);
- z = Math.sign(z) * (Math.exp(Math.abs(z)) - 1);
- setPackedSplatCenter(packedArray, i, x, y, z);
- }
- });
-
- const scalesPromise = decodeImageRgba(extraFiles[json.scales.files[0]]).then(
- (scales) => {
- let xLookup: number[];
- let yLookup: number[];
- let zLookup: number[];
-
- if (isVersion2) {
- xLookup =
- yLookup =
- zLookup =
- json.scales.codebook.map((x) => Math.exp(x));
- } else {
- xLookup = new Array(256)
- .fill(0)
- .map(
- (_, i) =>
- json.scales.mins[0] +
- (json.scales.maxs[0] - json.scales.mins[0]) * (i / 255),
- )
- .map((x) => Math.exp(x));
- yLookup = new Array(256)
- .fill(0)
- .map(
- (_, i) =>
- json.scales.mins[1] +
- (json.scales.maxs[1] - json.scales.mins[1]) * (i / 255),
- )
- .map((x) => Math.exp(x));
- zLookup = new Array(256)
- .fill(0)
- .map(
- (_, i) =>
- json.scales.mins[2] +
- (json.scales.maxs[2] - json.scales.mins[2]) * (i / 255),
- )
- .map((x) => Math.exp(x));
- }
-
- for (let i = 0; i < numSplats; ++i) {
- const i4 = i * 4;
- setPackedSplatScales(
- packedArray,
- i,
- xLookup[scales[i4 + 0]],
- yLookup[scales[i4 + 1]],
- zLookup[scales[i4 + 2]],
- splatEncoding,
- );
- }
- },
- );
-
- const quatsPromise = decodeImageRgba(extraFiles[json.quats.files[0]]).then(
- (quats) => {
- const SQRT2 = Math.sqrt(2);
- const lookup = new Array(256)
- .fill(0)
- .map((_, i) => (i / 255 - 0.5) * SQRT2);
-
- for (let i = 0; i < numSplats; ++i) {
- const i4 = i * 4;
- const r0 = lookup[quats[i4 + 0]];
- const r1 = lookup[quats[i4 + 1]];
- const r2 = lookup[quats[i4 + 2]];
- const rr = Math.sqrt(Math.max(0, 1.0 - r0 * r0 - r1 * r1 - r2 * r2));
- const rOrder = quats[i4 + 3] - 252;
- const quatX = rOrder === 0 ? r0 : rOrder === 1 ? rr : r1;
- const quatY = rOrder <= 1 ? r1 : rOrder === 2 ? rr : r2;
- const quatZ = rOrder <= 2 ? r2 : rr;
- const quatW = rOrder === 0 ? rr : r0;
- setPackedSplatQuat(packedArray, i, quatX, quatY, quatZ, quatW);
- }
- },
- );
- const sh0Promise = decodeImageRgba(extraFiles[json.sh0.files[0]]).then(
- (sh0) => {
- const SH_C0 = 0.28209479177387814;
- let rLookup: number[];
- let gLookup: number[];
- let bLookup: number[];
- let aLookup: number[];
-
- if (isVersion2) {
- rLookup =
- gLookup =
- bLookup =
- json.sh0.codebook.map((x) => SH_C0 * x + 0.5);
- aLookup = new Array(256).fill(0).map((_, i) => i / 255);
- } else {
- rLookup = new Array(256)
- .fill(0)
- .map(
- (_, i) =>
- json.sh0.mins[0] +
- (json.sh0.maxs[0] - json.sh0.mins[0]) * (i / 255),
- )
- .map((x) => SH_C0 * x + 0.5);
- gLookup = new Array(256)
- .fill(0)
- .map(
- (_, i) =>
- json.sh0.mins[1] +
- (json.sh0.maxs[1] - json.sh0.mins[1]) * (i / 255),
- )
- .map((x) => SH_C0 * x + 0.5);
- bLookup = new Array(256)
- .fill(0)
- .map(
- (_, i) =>
- json.sh0.mins[2] +
- (json.sh0.maxs[2] - json.sh0.mins[2]) * (i / 255),
- )
- .map((x) => SH_C0 * x + 0.5);
- aLookup = new Array(256)
- .fill(0)
- .map(
- (_, i) =>
- json.sh0.mins[3] +
- (json.sh0.maxs[3] - json.sh0.mins[3]) * (i / 255),
- )
- .map((x) => 1.0 / (1.0 + Math.exp(-x)));
- }
-
- for (let i = 0; i < numSplats; ++i) {
- const i4 = i * 4;
- setPackedSplatRgba(
- packedArray,
- i,
- rLookup[sh0[i4 + 0]],
- gLookup[sh0[i4 + 1]],
- bLookup[sh0[i4 + 2]],
- aLookup[sh0[i4 + 3]],
- splatEncoding,
- );
- }
- },
- );
-
- const promises = [meansPromise, scalesPromise, quatsPromise, sh0Promise];
- if (json.shN) {
- const useSH3 = isVersion2
- ? json.shN.bands >= 3
- : json.shN.shape[1] >= 48 - 3;
- const useSH2 = isVersion2
- ? json.shN.bands >= 2
- : json.shN.shape[1] >= 27 - 3;
- const useSH1 = isVersion2
- ? json.shN.bands >= 1
- : json.shN.shape[1] >= 12 - 3;
-
- if (useSH1) extra.sh1 = new Uint32Array(numSplats * 2);
- if (useSH2) extra.sh2 = new Uint32Array(numSplats * 4);
- if (useSH3) extra.sh3 = new Uint32Array(numSplats * 4);
-
- const sh1 = new Float32Array(9);
- const sh2 = new Float32Array(15);
- const sh3 = new Float32Array(21);
-
- const shN = json.shN;
- const shNPromise = Promise.all([
- decodeImage(extraFiles[json.shN.files[0]]),
- decodeImage(extraFiles[json.shN.files[1]]),
- ]).then(([centroids, labels]) => {
- const lookup =
- "codebook" in shN
- ? shN.codebook
- : new Array(256)
- .fill(0)
- .map((_, i) => shN.mins + (shN.maxs - shN.mins) * (i / 255));
-
- for (let i = 0; i < numSplats; ++i) {
- const i4 = i * 4;
- const label = labels.rgba[i4 + 0] + (labels.rgba[i4 + 1] << 8);
- const col = (label & 63) * 15;
- const row = label >>> 6;
- const offset = row * centroids.width + col;
-
- for (let d = 0; d < 3; ++d) {
- if (useSH1) {
- for (let k = 0; k < 3; ++k) {
- sh1[k * 3 + d] = lookup[centroids.rgba[(offset + k) * 4 + d]];
- }
- }
-
- if (useSH2) {
- for (let k = 0; k < 5; ++k) {
- sh2[k * 3 + d] = lookup[centroids.rgba[(offset + 3 + k) * 4 + d]];
- }
- }
-
- if (useSH3) {
- for (let k = 0; k < 7; ++k) {
- sh3[k * 3 + d] = lookup[centroids.rgba[(offset + 8 + k) * 4 + d]];
- }
- }
- }
-
- if (useSH1)
- encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1, splatEncoding);
- if (useSH2)
- encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2, splatEncoding);
- if (useSH3)
- encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3, splatEncoding);
- }
- });
- promises.push(shNPromise);
- }
-
- await Promise.all(promises);
-
- return { packedArray, numSplats, extra };
-}
-
-// WebGL context for reading raw pixel data of WebP images
-let offscreenGlContext: WebGL2RenderingContext | null = null;
-
-async function decodeImage(fileBytes: ArrayBuffer) {
- if (!offscreenGlContext) {
- const canvas = new OffscreenCanvas(1, 1);
- offscreenGlContext = canvas.getContext("webgl2");
- if (!offscreenGlContext) {
- throw new Error("Failed to create WebGL2 context");
- }
- }
-
- const imageBlob = new Blob([fileBytes]);
- const bitmap = await createImageBitmap(imageBlob, {
- premultiplyAlpha: "none",
- });
-
- const gl = offscreenGlContext;
- const texture = gl.createTexture();
- gl.bindTexture(gl.TEXTURE_2D, texture);
- gl.pixelStorei(gl.UNPACK_FLIP_Y_WEBGL, true);
- gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, bitmap);
- gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
- gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
-
- const framebuffer = gl.createFramebuffer();
- gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
- gl.framebufferTexture2D(
- gl.FRAMEBUFFER,
- gl.COLOR_ATTACHMENT0,
- gl.TEXTURE_2D,
- texture,
- 0,
- );
-
- const data = new Uint8Array(bitmap.width * bitmap.height * 4);
- gl.readPixels(
- 0,
- 0,
- bitmap.width,
- bitmap.height,
- gl.RGBA,
- gl.UNSIGNED_BYTE,
- data,
- );
-
- gl.deleteTexture(texture);
- gl.deleteFramebuffer(framebuffer);
-
- return { rgba: data, width: bitmap.width, height: bitmap.height };
-}
-
-async function decodeImageRgba(fileBytes: ArrayBuffer) {
- const { rgba } = await decodeImage(fileBytes);
- return rgba;
-}
-
-export async function unpackPcSogsZip(
- fileBytes: Uint8Array,
- splatEncoding: SplatEncoding,
-): Promise<{
- packedArray: Uint32Array;
- numSplats: number;
- extra: Record;
-}> {
- const nameJson = tryPcSogsZip(fileBytes);
- if (!nameJson) {
- throw new Error("Invalid PC SOGS zip file");
- }
- const { name, json } = nameJson;
- // Find path prefix, will be -1 if no / or \
- const lastSlash = name.lastIndexOf("/");
- const lastBackslash = name.lastIndexOf("\\");
- const prefix = name.slice(0, Math.max(lastSlash, lastBackslash) + 1);
-
- const fileMap = new Map();
- const refFiles = [
- ...json.means.files,
- ...json.scales.files,
- ...json.quats.files,
- ...json.sh0.files,
- ...(json.shN?.files ?? []),
- ];
- for (const file of refFiles) {
- fileMap.set(prefix + file, file);
- }
-
- const unzipped = await new Promise>(
- (resolve, reject) => {
- unzip(
- fileBytes,
- {
- filter: ({ name }) => {
- return fileMap.has(name);
- },
- },
- (err, files) => {
- if (err) {
- reject(err);
- } else {
- resolve(files);
- }
- },
- );
- },
- );
-
- const extraFiles: Record = {};
- for (const [full, name] of fileMap.entries()) {
- extraFiles[name] = unzipped[full];
- }
-
- return await unpackPcSogs(json, extraFiles, splatEncoding);
-}
diff --git a/src/ply.ts b/src/ply.ts
deleted file mode 100644
index b3ce24bd..00000000
--- a/src/ply.ts
+++ /dev/null
@@ -1,1091 +0,0 @@
-// PLY file format reader
-
-import { USE_COMPILED_PARSER_FUNCTION } from "./defines";
-
-const PLY_PROPERTY_TYPES = [
- "char",
- "uchar",
- "short",
- "ushort",
- "int",
- "uint",
- "float",
- "double",
-] as const;
-export type PlyPropertyType = (typeof PLY_PROPERTY_TYPES)[number];
-
-export type PlyElement = {
- name: string;
- count: number;
- properties: Record;
-};
-
-export type PlyProperty = {
- isList: boolean;
- type: PlyPropertyType;
- countType?: PlyPropertyType;
-};
-
-// Callback for parseSplats base Gsplat data
-export type SplatCallback = (
- index: number,
- x: number,
- y: number,
- z: number,
- scaleX: number,
- scaleY: number,
- scaleZ: number,
- quatX: number,
- quatY: number,
- quatZ: number,
- quatW: number,
- opacity: number,
- r: number,
- g: number,
- b: number,
-) => void;
-
-// Callback for parseSplats SH coefficients
-export type SplatShCallback = (
- index: number,
- sh1: Float32Array,
- sh2?: Float32Array,
- sh3?: Float32Array,
-) => void;
-
-// A PlyReader is used to parse PLY files for Gsplat data.
-// It takes a Uint8Array/ArrayBuffer as input fileBytes, parses the text header,
-// and provides a method parseData to iterate over the entire binary data
-// efficiently, or parseSplats to iterate over Gsplat data.
-
-export class PlyReader {
- fileBytes: Uint8Array;
- header = "";
- littleEndian = true;
- elements: Record = {};
- comments: string[] = [];
- data: DataView | null = null;
- static defaultPointScale = 0.001;
-
- numSplats = 0;
-
- // Create a PlyReader from a Uint8Array/ArrayBuffer, no parsing done yet
- constructor({ fileBytes }: { fileBytes: Uint8Array | ArrayBuffer }) {
- this.fileBytes =
- fileBytes instanceof ArrayBuffer ? new Uint8Array(fileBytes) : fileBytes;
- }
-
- // Identify and parse the PLY text header (assumed to be <64KB in size).
- // this.elements will contain all the elements in the file, typically
- // "vertex" contains the Gsplat data.
- async parseHeader() {
- const bufferStream = new ReadableStream({
- start: (
- controller: ReadableStreamController>,
- ) => {
- // Assume the header is less than 64KB
- controller.enqueue(this.fileBytes.slice(0, 65536));
- controller.close();
- },
- });
- const decoder = bufferStream
- .pipeThrough(new TextDecoderStream())
- .getReader();
-
- // Find the end of the text section of the PLY file
- this.header = "";
- const headerTerminator = "end_header\n";
- while (true) {
- const { value, done } = await decoder.read();
- if (done) {
- throw new Error("Failed to read header");
- }
-
- this.header += value as string;
- const endHeader = this.header.indexOf(headerTerminator);
- if (endHeader >= 0) {
- this.header = this.header.slice(0, endHeader + headerTerminator.length);
- break;
- }
- }
- // Partition the file into header and binary data
- const headerLen = new TextEncoder().encode(this.header).length;
- this.data = new DataView(this.fileBytes.buffer, headerLen);
-
- this.elements = {};
- let curElement: PlyElement | null = null;
- this.comments = [];
-
- this.header
- .trim()
- .split("\n")
- .forEach((line: string, lineIndex: number) => {
- const trimmedLine = line.trim();
- if (lineIndex === 0) {
- if (trimmedLine !== "ply") {
- throw new Error("Invalid PLY header");
- }
- return;
- }
- if (trimmedLine.length === 0) {
- return; // Skip empty lines
- }
-
- const fields = trimmedLine.split(" ");
- switch (fields[0]) {
- case "format":
- if (fields[1] === "binary_little_endian") {
- this.littleEndian = true;
- } else if (fields[1] === "binary_big_endian") {
- this.littleEndian = false;
- } else {
- // ascii formats not supported
- throw new Error(`Unsupported PLY format: ${fields[1]}`);
- }
- if (fields[2] !== "1.0") {
- throw new Error(`Unsupported PLY version: ${fields[2]}`);
- }
- break;
- case "end_header":
- break;
- case "comment":
- this.comments.push(trimmedLine.slice("comment ".length));
- break;
- case "element": {
- const name = fields[1];
- curElement = {
- name,
- count: Number.parseInt(fields[2]),
- properties: {},
- };
- this.elements[name] = curElement;
- break;
- }
- case "property":
- if (curElement == null) {
- throw new Error("Property must be inside an element");
- }
- if (fields[1] === "list") {
- curElement.properties[fields[4]] = {
- isList: true,
- type: fields[3] as PlyPropertyType,
- countType: fields[2] as PlyPropertyType,
- };
- } else {
- curElement.properties[fields[2]] = {
- isList: false,
- type: fields[1] as PlyPropertyType,
- };
- }
- break;
- default:
- // console.warn(`Skipping unsupported PLY keyword: ${fields[0]}`);
- }
- });
-
- if (this.elements.vertex) {
- this.numSplats = this.elements.vertex.count;
- }
- }
-
- parseData(
- elementCallback: (
- element: PlyElement,
- ) =>
- | null
- | ((index: number, item: Record) => void),
- ) {
- // Go through the entire binary data of the PLY file, starting at offset 0
- let offset = 0;
- const data = this.data;
- if (data == null) {
- throw new Error("No data to parse");
- }
-
- for (const elementName in this.elements) {
- const element = this.elements[elementName];
- const { count, properties } = element;
- const item = createEmptyItem(properties);
- // Construct a parse function
- const parseFn = createParseFn(properties, this.littleEndian);
-
- // Parse all the items in the element
- const callback = elementCallback(element) ?? (() => {});
- for (let index = 0; index < count; index++) {
- offset = parseFn(data, offset, item);
- callback(index, item);
- }
- }
- }
-
- // Parse all the Gsplat data in the PLY file in go, invoking the given
- // callbacks for each Gsplat.
- parseSplats(splatCallback: SplatCallback, shCallback?: SplatShCallback) {
- if (this.elements.vertex == null) {
- throw new Error("No vertex element found");
- }
-
- let isSuperSplat = false;
- const ssChunks: SSChunk[] = [];
-
- let numSh = 0;
- let sh1Props: number[] = [];
- let sh2Props: number[] = [];
- let sh3Props: number[] = [];
- let sh1: Float32Array | undefined = undefined;
- let sh2: Float32Array | undefined = undefined;
- let sh3: Float32Array | undefined = undefined;
-
- function prepareSh() {
- // Prepare SH coefficient names and arrays for numSh total SH levels
- const num_f_rest = NUM_SH_TO_NUM_F_REST[numSh];
- sh1Props = new Array(3)
- .fill(null)
- .flatMap((_, k) => [0, 1, 2].map((_, d) => k + (d * num_f_rest) / 3));
- sh2Props = new Array(5)
- .fill(null)
- .flatMap((_, k) =>
- [0, 1, 2].map((_, d) => 3 + k + (d * num_f_rest) / 3),
- );
- sh3Props = new Array(7)
- .fill(null)
- .flatMap((_, k) =>
- [0, 1, 2].map((_, d) => 8 + k + (d * num_f_rest) / 3),
- );
- sh1 = numSh >= 1 ? new Float32Array(3 * 3) : undefined;
- sh2 = numSh >= 2 ? new Float32Array(5 * 3) : undefined;
- sh3 = numSh >= 3 ? new Float32Array(7 * 3) : undefined;
- }
-
- function ssShCallback(
- index: number,
- item: Record,
- ) {
- // Decode SH for SuperSplat compressed data
- if (!sh1) {
- throw new Error("Missing sh1");
- }
- const sh = item.f_rest as number[];
-
- for (let i = 0; i < sh1Props.length; i++) {
- sh1[i] = (sh[sh1Props[i]] * 8) / 255 - 4;
- }
- if (sh2) {
- for (let i = 0; i < sh2Props.length; i++) {
- sh2[i] = (sh[sh2Props[i]] * 8) / 255 - 4;
- }
- }
- if (sh3) {
- for (let i = 0; i < sh3Props.length; i++) {
- sh3[i] = (sh[sh3Props[i]] * 8) / 255 - 4;
- }
- }
- shCallback?.(index, sh1, sh2, sh3);
- }
-
- function initSuperSplat(element: PlyElement) {
- const {
- min_x,
- min_y,
- min_z,
- max_x,
- max_y,
- max_z,
- min_scale_x,
- min_scale_y,
- min_scale_z,
- max_scale_x,
- max_scale_y,
- max_scale_z,
- } = element.properties;
- if (
- !min_x ||
- !min_y ||
- !min_z ||
- !max_x ||
- !max_y ||
- !max_z ||
- !min_scale_x ||
- !min_scale_y ||
- !min_scale_z ||
- !max_scale_x ||
- !max_scale_y ||
- !max_scale_z
- ) {
- throw new Error("Missing PLY chunk properties");
- }
-
- // SuperSplat chunks are used to quantize splat data, so we need to store them
- isSuperSplat = true;
- return (index: number, item: Record) => {
- const {
- min_x,
- min_y,
- min_z,
- max_x,
- max_y,
- max_z,
- min_scale_x,
- min_scale_y,
- min_scale_z,
- max_scale_x,
- max_scale_y,
- max_scale_z,
- min_r,
- min_g,
- min_b,
- max_r,
- max_g,
- max_b,
- } = item as Record;
- ssChunks.push({
- min_x,
- min_y,
- min_z,
- max_x,
- max_y,
- max_z,
- min_scale_x,
- min_scale_y,
- min_scale_z,
- max_scale_x,
- max_scale_y,
- max_scale_z,
- min_r,
- min_g,
- min_b,
- max_r,
- max_g,
- max_b,
- });
- };
- }
-
- function decodeSuperSplat(element: PlyElement) {
- // Decode SuperSplat compressed data in vertex and sh elements
- if (shCallback && element.name === "sh") {
- numSh = getNumSh(element.properties);
- prepareSh();
- return ssShCallback;
- }
- if (element.name !== "vertex") {
- return null;
- }
-
- const { packed_position, packed_rotation, packed_scale, packed_color } =
- element.properties;
- if (
- !packed_position ||
- !packed_rotation ||
- !packed_scale ||
- !packed_color
- ) {
- throw new Error(
- "Missing PLY properties: packed_position, packed_rotation, packed_scale, packed_color",
- );
- }
-
- const SQRT2 = Math.sqrt(2);
-
- return (index: number, item: Record) => {
- // SuperSplat data are quantized within chunks with 256 Gsplats each
- const chunk = ssChunks[index >>> 8];
- if (chunk == null) {
- throw new Error("Missing PLY chunk");
- }
- const {
- min_x,
- min_y,
- min_z,
- max_x,
- max_y,
- max_z,
- min_scale_x,
- min_scale_y,
- min_scale_z,
- max_scale_x,
- max_scale_y,
- max_scale_z,
- min_r,
- min_g,
- min_b,
- max_r,
- max_g,
- max_b,
- } = chunk;
- const { packed_position, packed_rotation, packed_scale, packed_color } =
- item as Record;
-
- const x =
- (((packed_position >>> 21) & 2047) / 2047) * (max_x - min_x) + min_x;
- const y =
- (((packed_position >>> 11) & 1023) / 1023) * (max_y - min_y) + min_y;
- const z = ((packed_position & 2047) / 2047) * (max_z - min_z) + min_z;
-
- const r0 = (((packed_rotation >>> 20) & 1023) / 1023 - 0.5) * SQRT2;
- const r1 = (((packed_rotation >>> 10) & 1023) / 1023 - 0.5) * SQRT2;
- const r2 = ((packed_rotation & 1023) / 1023 - 0.5) * SQRT2;
- const rr = Math.sqrt(Math.max(0, 1.0 - r0 * r0 - r1 * r1 - r2 * r2));
-
- const rOrder = packed_rotation >>> 30;
- const quatX = rOrder === 0 ? r0 : rOrder === 1 ? rr : r1;
- const quatY = rOrder <= 1 ? r1 : rOrder === 2 ? rr : r2;
- const quatZ = rOrder <= 2 ? r2 : rr;
- const quatW = rOrder === 0 ? rr : r0;
-
- const scaleX = Math.exp(
- (((packed_scale >>> 21) & 2047) / 2047) *
- (max_scale_x - min_scale_x) +
- min_scale_x,
- );
- const scaleY = Math.exp(
- (((packed_scale >>> 11) & 1023) / 1023) *
- (max_scale_y - min_scale_y) +
- min_scale_y,
- );
- const scaleZ = Math.exp(
- ((packed_scale & 2047) / 2047) * (max_scale_z - min_scale_z) +
- min_scale_z,
- );
-
- const r =
- (((packed_color >>> 24) & 255) / 255) *
- ((max_r ?? 1) - (min_r ?? 0)) +
- (min_r ?? 0);
- const g =
- (((packed_color >>> 16) & 255) / 255) *
- ((max_g ?? 1) - (min_g ?? 0)) +
- (min_g ?? 0);
- const b =
- (((packed_color >>> 8) & 255) / 255) * ((max_b ?? 1) - (min_b ?? 0)) +
- (min_b ?? 0);
- const opacity = (packed_color & 255) / 255;
-
- splatCallback(
- index,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- );
- };
- }
-
- const elementCallback = (element: PlyElement) => {
- if (element.name === "chunk") {
- // "chunk" could conceivably be used for other formats, and we would
- // ideally check for the comment: Generated by SuperSplat 2.*
- // but gsplat also outputs this format without such a comment.
- // In order to support both, let's assume a "chunk" element should
- // be interpreted as this format.
- return initSuperSplat(element);
- }
- if (isSuperSplat) {
- return decodeSuperSplat(element);
- }
-
- if (element.name !== "vertex") {
- return null;
- }
-
- const {
- x,
- y,
- z,
- scale_0,
- scale_1,
- scale_2,
- rot_0,
- rot_1,
- rot_2,
- rot_3,
- opacity,
- f_dc_0,
- f_dc_1,
- f_dc_2,
- red,
- green,
- blue,
- alpha,
- } = element.properties;
-
- if (!x || !y || !z) {
- throw new Error("Missing PLY properties: x, y, z");
- }
- // Pure point cloud PLY files have no scales or rotations
- const hasScales = scale_0 && scale_1 && scale_2;
- const hasRots = rot_0 && rot_1 && rot_2 && rot_3;
- // Quantization scale factor for argb values
- const alphaDiv = alpha != null ? FIELD_SCALE[alpha.type] : 1;
- const redDiv = red != null ? FIELD_SCALE[red.type] : 1;
- const greenDiv = green != null ? FIELD_SCALE[green.type] : 1;
- const blueDiv = blue != null ? FIELD_SCALE[blue.type] : 1;
-
- numSh = getNumSh(element.properties);
- prepareSh();
-
- return (index: number, item: Record) => {
- const scaleX = hasScales
- ? Math.exp(item.scale_0 as number)
- : PlyReader.defaultPointScale;
- const scaleY = hasScales
- ? Math.exp(item.scale_1 as number)
- : PlyReader.defaultPointScale;
- const scaleZ = hasScales
- ? Math.exp(item.scale_2 as number)
- : PlyReader.defaultPointScale;
-
- const quatX = hasRots ? (item.rot_1 as number) : 0;
- const quatY = hasRots ? (item.rot_2 as number) : 0;
- const quatZ = hasRots ? (item.rot_3 as number) : 0;
- const quatW = hasRots ? (item.rot_0 as number) : 1;
-
- const op =
- opacity != null
- ? 1.0 / (1.0 + Math.exp(-item.opacity as number))
- : alpha != null
- ? (item.alpha as number) / alphaDiv
- : 1.0;
- const r =
- f_dc_0 != null
- ? (item.f_dc_0 as number) * SH_C0 + 0.5
- : red != null
- ? (item.red as number) / redDiv
- : 1.0;
- const g =
- f_dc_1 != null
- ? (item.f_dc_1 as number) * SH_C0 + 0.5
- : green != null
- ? (item.green as number) / greenDiv
- : 1.0;
- const b =
- f_dc_2 != null
- ? (item.f_dc_2 as number) * SH_C0 + 0.5
- : blue != null
- ? (item.blue as number) / blueDiv
- : 1.0;
-
- splatCallback(
- index,
- item.x as number,
- item.y as number,
- item.z as number,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- op,
- r,
- g,
- b,
- );
-
- if (shCallback && sh1) {
- const sh = item.f_rest as number[];
- if (sh1) {
- for (let i = 0; i < sh1Props.length; i++) {
- sh1[i] = sh[sh1Props[i]];
- }
- }
- if (sh2) {
- for (let i = 0; i < sh2Props.length; i++) {
- sh2[i] = sh[sh2Props[i]];
- }
- }
- if (sh3) {
- for (let i = 0; i < sh3Props.length; i++) {
- sh3[i] = sh[sh3Props[i]];
- }
- }
- shCallback(index, sh1, sh2, sh3);
- }
- };
- };
-
- this.parseData(elementCallback);
- }
-
- // Inject RGBA values into original PLY file, which can be used to modify
- // the color/opacity of the Gsplats and write out the modified PLY file.
- injectRgba(rgba: Uint8Array) {
- // Go through the entire binary data of the PLY file, starting at offset 0
- let offset = 0;
- const data = this.data;
- if (data == null) {
- throw new Error("No parsed data");
- }
- if (rgba.length !== this.numSplats * 4) {
- throw new Error("Invalid RGBA array length");
- }
-
- for (const elementName in this.elements) {
- const element = this.elements[elementName];
- const { count, properties } = element;
- const parsers = [];
-
- let rgbaOffset = 0;
- const isVertex = elementName === "vertex";
- if (isVertex) {
- for (const name of ["opacity", "f_dc_0", "f_dc_1", "f_dc_2"]) {
- if (!properties[name] || properties[name].type !== "float") {
- throw new Error(`Can't injectRgba due to property: ${name}`);
- }
- }
- }
-
- for (const [propertyName, property] of Object.entries(properties)) {
- if (!property.isList) {
- if (isVertex) {
- if (
- propertyName === "f_dc_0" ||
- propertyName === "f_dc_1" ||
- propertyName === "f_dc_2"
- ) {
- const component = Number.parseInt(
- propertyName.slice("f_dc_".length),
- );
- parsers.push(() => {
- // Inject DC coefficients
- const value =
- (rgba[rgbaOffset + component] / 255 - 0.5) / SH_C0;
- SET_FIELD[property.type](
- data,
- offset,
- this.littleEndian,
- value,
- );
- });
- } else if (propertyName === "opacity") {
- parsers.push(() => {
- // Inject opacity sigmoid, clamped to [-100, 100]
- const value = Math.max(
- -100,
- Math.min(
- 100,
- -Math.log(1.0 / (rgba[rgbaOffset + 3] / 255) - 1.0),
- ),
- );
- SET_FIELD[property.type](
- data,
- offset,
- this.littleEndian,
- value,
- );
- });
- }
- }
- parsers.push(() => {
- offset += FIELD_BYTES[property.type];
- });
- } else {
- parsers.push(() => {
- const length = PARSE_FIELD[property.countType as PlyPropertyType](
- data,
- offset,
- this.littleEndian,
- );
- offset += FIELD_BYTES[property.countType as PlyPropertyType];
- offset += length * FIELD_BYTES[property.type];
- });
- }
- }
-
- for (let index = 0; index < count; index++) {
- // Go through all the data and field parsers to compute offset
- for (const parser of parsers) {
- parser();
- }
- if (isVertex) {
- rgbaOffset += 4;
- }
- }
- }
- }
-}
-
-export const SH_C0 = 0.28209479177387814;
-
-type FieldParser = (
- data: DataView,
- offset: number,
- littleEndian: boolean,
-) => number;
-type FieldSetter = (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
-) => void;
-
-const PARSE_FIELD: Record = {
- char: (data: DataView, offset: number, littleEndian: boolean) => {
- return data.getInt8(offset);
- },
- uchar: (data: DataView, offset: number, littleEndian: boolean) => {
- return data.getUint8(offset);
- },
- short: (data: DataView, offset: number, littleEndian: boolean) => {
- return data.getInt16(offset, littleEndian);
- },
- ushort: (data: DataView, offset: number, littleEndian: boolean) => {
- return data.getUint16(offset, littleEndian);
- },
- int: (data: DataView, offset: number, littleEndian: boolean) => {
- return data.getInt32(offset, littleEndian);
- },
- uint: (data: DataView, offset: number, littleEndian: boolean) => {
- return data.getUint32(offset, littleEndian);
- },
- float: (data: DataView, offset: number, littleEndian: boolean) => {
- return data.getFloat32(offset, littleEndian);
- },
- double: (data: DataView, offset: number, littleEndian: boolean) => {
- return data.getFloat64(offset, littleEndian);
- },
-};
-
-const SET_FIELD: Record = {
- char: (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
- ) => {
- data.setInt8(offset, value);
- },
- uchar: (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
- ) => {
- data.setUint8(offset, value);
- },
- short: (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
- ) => {
- data.setInt16(offset, value, littleEndian);
- },
- ushort: (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
- ) => {
- data.setUint16(offset, value, littleEndian);
- },
- int: (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
- ) => {
- data.setInt32(offset, value, littleEndian);
- },
- uint: (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
- ) => {
- data.setUint32(offset, value, littleEndian);
- },
- float: (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
- ) => {
- data.setFloat32(offset, value, littleEndian);
- },
- double: (
- data: DataView,
- offset: number,
- littleEndian: boolean,
- value: number,
- ) => {
- data.setFloat64(offset, value, littleEndian);
- },
-};
-
-const FIELD_BYTES: Record = {
- char: 1,
- uchar: 1,
- short: 2,
- ushort: 2,
- int: 4,
- uint: 4,
- float: 4,
- double: 8,
-};
-
-const FIELD_SCALE: Record = {
- char: 127,
- uchar: 255,
- short: 32767,
- ushort: 65535,
- int: 2147483647,
- uint: 4294967295,
- float: 1,
- double: 1,
-};
-
-const NUM_F_REST_TO_NUM_SH: Record = {
- 0: 0,
- 9: 1,
- 24: 2,
- 45: 3,
-};
-const NUM_SH_TO_NUM_F_REST: Record = {
- 0: 0,
- 1: 9,
- 2: 24,
- 3: 45,
-};
-
-const F_REST_REGEX = /^f_rest_([0-9]{1,2})$/;
-
-function createEmptyItem(
- properties: Record,
-): Record {
- const item: Record = {};
- for (const [propertyName, property] of Object.entries(properties)) {
- // Treat f_rest properties as a single array for performance
- if (F_REST_REGEX.test(propertyName)) {
- item.f_rest = new Array(getNumSh(properties));
- } else {
- item[propertyName] = property.isList ? [] : 0;
- }
- }
- return item;
-}
-
-function createParseFn(
- properties: Record,
- littleEndian: boolean,
-) {
- if (USE_COMPILED_PARSER_FUNCTION && safeToCompile(properties)) {
- return createCompiledParserFn(properties, littleEndian);
- }
- return createDynamicParserFn(properties, littleEndian);
-}
-
-// Detect if unsafe eval is allowed in the current execution context
-const UNSAFE_EVAL_ALLOWED = (() => {
- try {
- new Function("return 42;");
- } catch (e) {
- return false;
- }
- return true;
-})();
-const PROPERTY_NAME_REGEX = /^[a-zA-Z0-9_]+$/;
-
-function safeToCompile(properties: Record) {
- if (!UNSAFE_EVAL_ALLOWED) {
- return false;
- }
-
- for (const [propertyName, property] of Object.entries(properties)) {
- if (!PROPERTY_NAME_REGEX.test(propertyName)) {
- return false;
- }
-
- if (
- property.isList &&
- !PLY_PROPERTY_TYPES.includes(property.countType as PlyPropertyType)
- ) {
- return false;
- }
-
- if (!PLY_PROPERTY_TYPES.includes(property.type)) {
- return false;
- }
- }
- return true;
-}
-
-function createCompiledParserFn(
- properties: Record,
- littleEndian: boolean,
-) {
- // Construct the parser function source.
- const parserSrc: string[] = ["let list;"];
- for (const [propertyName, property] of Object.entries(properties)) {
- const fRestMatch = propertyName.match(F_REST_REGEX);
- if (fRestMatch) {
- const fRestIndex = +fRestMatch[1];
- parserSrc.push(/*js*/ `
- item.f_rest[${fRestIndex}] = PARSE_FIELD['${property.type}'](data, offset, ${littleEndian});
- offset += ${FIELD_BYTES[property.type]};
- `);
- } else if (!property.isList) {
- parserSrc.push(/*js*/ `
- item['${propertyName}'] = PARSE_FIELD['${property.type}'](data, offset, ${littleEndian});
- offset += ${FIELD_BYTES[property.type]};
- `);
- } else {
- // Property is a list, so parse the count first
- parserSrc.push(/*js*/ `
- list = item['${propertyName}'];
- list.length = PARSE_FIELD['${property.countType}'](data, offset, ${littleEndian});
- offset += ${FIELD_BYTES[property.countType as PlyPropertyType]};
- for (let i = 0; i < list.length; i++) {
- list[i] = PARSE_FIELD['${property.type}'](data, offset, ${littleEndian});
- offset += ${FIELD_BYTES[property.type]};
- }
- `);
- }
- }
- parserSrc.push("return offset;");
-
- const fn = new Function(
- "data",
- "offset",
- "item",
- "PARSE_FIELD",
- parserSrc.join("\n"),
- );
- return (
- data: DataView,
- offset: number,
- item: Record,
- ) => fn(data, offset, item, PARSE_FIELD);
-}
-
-function createDynamicParserFn(
- properties: Record,
- littleEndian: boolean,
-) {
- // Construct an array of parser function to parse each property in an item
- const parsers: Array<
- (
- data: DataView,
- offset: number,
- item: Record,
- ) => number
- > = [];
- for (const [propertyName, property] of Object.entries(properties)) {
- const fRestMatch = propertyName.match(F_REST_REGEX);
- if (fRestMatch) {
- const fRestIndex = +fRestMatch[1];
- parsers.push(
- (
- data: DataView,
- offset: number,
- item: Record,
- ) => {
- (item.f_rest as number[])[fRestIndex] = PARSE_FIELD[property.type](
- data,
- offset,
- littleEndian,
- );
- return offset + FIELD_BYTES[property.type];
- },
- );
- } else if (!property.isList) {
- parsers.push(
- (
- data: DataView,
- offset: number,
- item: Record,
- ) => {
- item[propertyName] = PARSE_FIELD[property.type](
- data,
- offset,
- littleEndian,
- );
- return offset + FIELD_BYTES[property.type];
- },
- );
- } else {
- // Property is a list, so parse the count first
- parsers.push(
- (
- data: DataView,
- offset: number,
- item: Record,
- ) => {
- const list = item[propertyName] as number[];
- list.length = PARSE_FIELD[property.countType as PlyPropertyType](
- data,
- offset,
- littleEndian,
- );
- let currentOffset =
- offset + FIELD_BYTES[property.countType as PlyPropertyType];
- for (let i = 0; i < list.length; i++) {
- list[i] = PARSE_FIELD[property.type](
- data,
- currentOffset,
- littleEndian,
- );
- currentOffset += FIELD_BYTES[property.type];
- }
- return currentOffset;
- },
- );
- }
- }
-
- return (
- data: DataView,
- offset: number,
- item: Record,
- ) => {
- let currentOffset = offset;
- for (let parserIndex = 0; parserIndex < parsers.length; parserIndex++) {
- currentOffset = parsers[parserIndex](data, currentOffset, item);
- }
- return currentOffset;
- };
-}
-
-function getNumSh(properties: Record) {
- let num_f_rest = 0;
- while (properties[`f_rest_${num_f_rest}`]) {
- num_f_rest += 1;
- }
- const numSh = NUM_F_REST_TO_NUM_SH[num_f_rest];
- if (numSh == null) {
- throw new Error(`Unsupported number of SH coefficients: ${num_f_rest}`);
- }
- return numSh;
-}
-
-type SSChunk = {
- min_x: number;
- min_y: number;
- min_z: number;
- max_x: number;
- max_y: number;
- max_z: number;
- min_scale_x: number;
- min_scale_y: number;
- min_scale_z: number;
- max_scale_x: number;
- max_scale_y: number;
- max_scale_z: number;
- min_r?: number;
- min_g?: number;
- min_b?: number;
- max_r?: number;
- max_g?: number;
- max_b?: number;
-};
diff --git a/src/spz.ts b/src/spz.ts
index 75d3d832..ac898cbe 100644
--- a/src/spz.ts
+++ b/src/spz.ts
@@ -1,519 +1,14 @@
-import * as THREE from "three";
+import type { PackedSplats } from "./PackedSplats";
import {
- SplatData,
type TranscodeSpzInput,
getSplatFileType,
getSplatFileTypeFromPath,
} from "./SplatLoader";
-import { GunzipReader, fromHalf, normalize } from "./utils";
-import { decodeAntiSplat } from "./antisplat";
-import { SplatFileType } from "./defines";
-import { decodeKsplat } from "./ksplat";
-import { PlyReader } from "./ply";
-
-// SPZ file format reader
-
-export class SpzReader {
- fileBytes: Uint8Array;
- reader: GunzipReader;
-
- version = -1;
- numSplats = 0;
- shDegree = 0;
- fractionalBits = 0;
- flags = 0;
- flagAntiAlias = false;
- flagLod = false;
- reserved = 0;
- headerParsed = false;
- parsed = false;
-
- constructor({ fileBytes }: { fileBytes: Uint8Array | ArrayBuffer }) {
- this.fileBytes =
- fileBytes instanceof ArrayBuffer ? new Uint8Array(fileBytes) : fileBytes;
- this.reader = new GunzipReader({
- fileBytes: this.fileBytes as Uint8Array,
- });
- }
-
- async parseHeader() {
- if (this.headerParsed) {
- throw new Error("SPZ file header already parsed");
- }
-
- const header = new DataView((await this.reader.read(16)).buffer);
- if (header.getUint32(0, true) !== 0x5053474e) {
- throw new Error("Invalid SPZ file");
- }
- this.version = header.getUint32(4, true);
- if (this.version < 1 || this.version > 3) {
- throw new Error(`Unsupported SPZ version: ${this.version}`);
- }
-
- this.numSplats = header.getUint32(8, true);
- this.shDegree = header.getUint8(12);
- this.fractionalBits = header.getUint8(13);
- this.flags = header.getUint8(14);
- this.flagAntiAlias = (this.flags & 0x01) !== 0;
- this.flagLod = (this.flags & 0x80) !== 0;
- this.reserved = header.getUint8(15);
- this.headerParsed = true;
- this.parsed = false;
- }
-
- async parseSplats(
- centerCallback?: (index: number, x: number, y: number, z: number) => void,
- alphaCallback?: (index: number, alpha: number) => void,
- rgbCallback?: (index: number, r: number, g: number, b: number) => void,
- scalesCallback?: (
- index: number,
- scaleX: number,
- scaleY: number,
- scaleZ: number,
- ) => void,
- quatCallback?: (
- index: number,
- quatX: number,
- quatY: number,
- quatZ: number,
- quatW: number,
- ) => void,
- shCallback?: (
- index: number,
- sh1: Float32Array,
- sh2?: Float32Array,
- sh3?: Float32Array,
- ) => void,
- {
- childCounts,
- childStarts,
- }: {
- childCounts?: (index: number, count: number) => void;
- childStarts?: (index: number, start: number) => void;
- } = {},
- ) {
- if (!this.headerParsed) {
- throw new Error("SPZ file header must be parsed first");
- }
- if (this.parsed) {
- throw new Error("SPZ file already parsed");
- }
- this.parsed = true;
-
- if (this.version === 1) {
- // float16 centers
- const centerBytes = await this.reader.read(this.numSplats * 3 * 2);
- const centerUint16 = new Uint16Array(centerBytes.buffer);
- for (let i = 0; i < this.numSplats; i++) {
- const i3 = i * 3;
- const x = fromHalf(centerUint16[i3]);
- const y = fromHalf(centerUint16[i3 + 1]);
- const z = fromHalf(centerUint16[i3 + 2]);
- centerCallback?.(i, x, y, z);
- }
- } else if (this.version === 2 || this.version === 3) {
- // 24-bit fixed-point centers
- const fixed = 1 << this.fractionalBits;
- const centerBytes = await this.reader.read(this.numSplats * 3 * 3);
- for (let i = 0; i < this.numSplats; i++) {
- const i9 = i * 9;
- const x =
- (((centerBytes[i9 + 2] << 24) |
- (centerBytes[i9 + 1] << 16) |
- (centerBytes[i9] << 8)) >>
- 8) /
- fixed;
- const y =
- (((centerBytes[i9 + 5] << 24) |
- (centerBytes[i9 + 4] << 16) |
- (centerBytes[i9 + 3] << 8)) >>
- 8) /
- fixed;
- const z =
- (((centerBytes[i9 + 8] << 24) |
- (centerBytes[i9 + 7] << 16) |
- (centerBytes[i9 + 6] << 8)) >>
- 8) /
- fixed;
- centerCallback?.(i, x, y, z);
- }
- } else {
- throw new Error("Unreachable");
- }
-
- {
- const bytes = await this.reader.read(this.numSplats);
- for (let i = 0; i < this.numSplats; i++) {
- alphaCallback?.(i, bytes[i] / 255);
- }
- }
- {
- const rgbBytes = await this.reader.read(this.numSplats * 3);
- const scale = SH_C0 / 0.15;
- for (let i = 0; i < this.numSplats; i++) {
- const i3 = i * 3;
- const r = (rgbBytes[i3] / 255 - 0.5) * scale + 0.5;
- const g = (rgbBytes[i3 + 1] / 255 - 0.5) * scale + 0.5;
- const b = (rgbBytes[i3 + 2] / 255 - 0.5) * scale + 0.5;
- rgbCallback?.(i, r, g, b);
- }
- }
- {
- const scalesBytes = await this.reader.read(this.numSplats * 3);
- for (let i = 0; i < this.numSplats; i++) {
- const i3 = i * 3;
- const scaleX = Math.exp(scalesBytes[i3] / 16 - 10);
- const scaleY = Math.exp(scalesBytes[i3 + 1] / 16 - 10);
- const scaleZ = Math.exp(scalesBytes[i3 + 2] / 16 - 10);
- scalesCallback?.(i, scaleX, scaleY, scaleZ);
- }
- }
- if (this.version === 3) {
- // Version 3 uses a trick called "smallest three" to compress the rotation quaternions
- // achieving better precision. "Optimizing orientation" section at https://gafferongames.com/post/snapshot_compression/ A quaternion length must be 1: x^2+y^2+z^2+w^2 = 1
- // We can drop one component and reconstruct it with the identity above.
- // Largest component is dropped for best numerical precision.
- // Quaternion stored in 32 bits
- // 10 bits singed integer for each of the 3 components + 2 bits indicating the index of dropped component.
- // vs 8 bits for each component uncompressed (spz version < 3)
- // Max Value after extracting largest component v is another component v
- // (v,v,0,0)
- // v^2 + v^2 = 1
- // v = 1 / sqrt(2);
- const maxValue = 1 / Math.sqrt(2); // 0.7071
- const quatBytes = await this.reader.read(this.numSplats * 4);
- for (let i = 0; i < this.numSplats; i++) {
- const i3 = i * 4;
- const quaternion = [0, 0, 0, 0];
- const values = [
- quatBytes[i3],
- quatBytes[i3 + 1],
- quatBytes[i3 + 2],
- quatBytes[i3 + 3],
- ];
- // all values are packed in 32 bits (10 per each of 3 components + 2 bits of index of larged value)
- const combinedValues =
- values[0] + (values[1] << 8) + (values[2] << 16) + (values[3] << 24);
- // each component value is 9 bits + sign (1 bit)
- const valueMask = (1 << 9) - 1;
- // extract index of the largest element. 2 top bits.
- const largestIndex = combinedValues >>> 30;
- let remainingValues = combinedValues;
- let sumSquares = 0;
-
- for (let i = 3; i >= 0; --i) {
- if (i !== largestIndex) {
- // extract current value and sign.
- const value = remainingValues & valueMask;
- const sign = (remainingValues >>> 9) & 0x1;
- // each value is represented as 10 bits. Shift to next one.
- remainingValues = remainingValues >>> 10;
- // convert to range [0,1] and then to [0, 0.7071]
- quaternion[i] = maxValue * (value / valueMask);
- // apply sign.
- quaternion[i] = sign === 0 ? quaternion[i] : -quaternion[i];
- // accumulate the sum of squares
- sumSquares += quaternion[i] * quaternion[i];
- }
- }
-
- // quartenion length must be 1 (x^2+y^2+z^2+w^2 = 1)
- // so can reconstruct largest component from the other 3.
- // w = sqrt(1 - x^2 - y^2 - z^2);
- const square = 1 - sumSquares;
- quaternion[largestIndex] = Math.sqrt(Math.max(square, 0));
-
- quatCallback?.(
- i,
- quaternion[0],
- quaternion[1],
- quaternion[2],
- quaternion[3],
- );
- }
- } else {
- const quatBytes = await this.reader.read(this.numSplats * 3);
- for (let i = 0; i < this.numSplats; i++) {
- const i3 = i * 3;
- const quatX = quatBytes[i3] / 127.5 - 1;
- const quatY = quatBytes[i3 + 1] / 127.5 - 1;
- const quatZ = quatBytes[i3 + 2] / 127.5 - 1;
- const quatW = Math.sqrt(
- Math.max(0, 1 - quatX * quatX - quatY * quatY - quatZ * quatZ),
- );
- quatCallback?.(i, quatX, quatY, quatZ, quatW);
- }
- }
-
- if (shCallback && this.shDegree >= 1) {
- const sh1 = new Float32Array(3 * 3);
- const sh2 = this.shDegree >= 2 ? new Float32Array(5 * 3) : undefined;
- const sh3 = this.shDegree >= 3 ? new Float32Array(7 * 3) : undefined;
- const shBytes = await this.reader.read(
- this.numSplats * SH_DEGREE_TO_VECS[this.shDegree] * 3,
- );
-
- let offset = 0;
- for (let i = 0; i < this.numSplats; i++) {
- for (let j = 0; j < 9; ++j) {
- sh1[j] = (shBytes[offset + j] - 128) / 128;
- }
- offset += 9;
- if (sh2) {
- for (let j = 0; j < 15; ++j) {
- sh2[j] = (shBytes[offset + j] - 128) / 128;
- }
- offset += 15;
- }
- if (sh3) {
- for (let j = 0; j < 21; ++j) {
- sh3[j] = (shBytes[offset + j] - 128) / 128;
- }
- offset += 21;
- }
- shCallback?.(i, sh1, sh2, sh3);
- }
- }
- if (this.flagLod) {
- let bytes = await this.reader.read(this.numSplats * 2);
- for (let i = 0; i < this.numSplats; i++) {
- const i2 = i * 2;
- const count = bytes[i2] + (bytes[i2 + 1] << 8);
- childCounts?.(i, count);
- }
-
- bytes = await this.reader.read(this.numSplats * 4);
- for (let i = 0; i < this.numSplats; i++) {
- const i4 = i * 4;
- const start =
- bytes[i4] +
- (bytes[i4 + 1] << 8) +
- (bytes[i4 + 2] << 16) +
- (bytes[i4 + 3] << 24);
- childStarts?.(i, start);
- }
- }
- }
-}
-
-const SH_DEGREE_TO_VECS: Record = { 1: 3, 2: 8, 3: 15 };
-const SH_C0 = 0.28209479177387814;
-
-export const SPZ_MAGIC = 0x5053474e; // NGSP = Niantic gaussian splat
-export const SPZ_VERSION = 3;
-export const FLAG_ANTIALIASED = 0x1;
-
-export class SpzWriter {
- buffer: ArrayBuffer;
- view: DataView;
- numSplats: number;
- shDegree: number;
- fractionalBits: number;
- fraction: number;
- flagAntiAlias: boolean;
- clippedCount = 0;
-
- constructor({
- numSplats,
- shDegree,
- fractionalBits = 12,
- flagAntiAlias = true,
- }: {
- numSplats: number;
- shDegree: number;
- fractionalBits?: number;
- flagAntiAlias?: boolean;
- }) {
- const splatSize =
- 9 + // Position
- 1 + // Opacity
- 3 + // Scale
- 3 + // DC-rgb
- 4 + // Rotation
- (shDegree >= 1 ? 9 : 0) +
- (shDegree >= 2 ? 15 : 0) +
- (shDegree >= 3 ? 21 : 0);
- const bufferSize = 16 + numSplats * splatSize;
- this.buffer = new ArrayBuffer(bufferSize);
- this.view = new DataView(this.buffer);
-
- this.view.setUint32(0, SPZ_MAGIC, true); // NGSP
- this.view.setUint32(4, SPZ_VERSION, true);
- this.view.setUint32(8, numSplats, true);
- this.view.setUint8(12, shDegree);
- this.view.setUint8(13, fractionalBits);
- this.view.setUint8(14, flagAntiAlias ? FLAG_ANTIALIASED : 0);
- this.view.setUint8(15, 0); // Reserved
-
- this.numSplats = numSplats;
- this.shDegree = shDegree;
- this.fractionalBits = fractionalBits;
- this.fraction = 1 << fractionalBits;
- this.flagAntiAlias = flagAntiAlias;
- }
-
- setCenter(index: number, x: number, y: number, z: number) {
- // Divide by this.fraction and round to nearest integer,
- // then write as 3-bytes per x then y then z.
- const xRounded = Math.round(x * this.fraction);
- const xInt = Math.max(-0x7fffff, Math.min(0x7fffff, xRounded));
- const yRounded = Math.round(y * this.fraction);
- const yInt = Math.max(-0x7fffff, Math.min(0x7fffff, yRounded));
- const zRounded = Math.round(z * this.fraction);
- const zInt = Math.max(-0x7fffff, Math.min(0x7fffff, zRounded));
- const clipped = xRounded !== xInt || yRounded !== yInt || zRounded !== zInt;
- if (clipped) {
- this.clippedCount += 1;
- // if (this.clippedCount < 10) {
- // // Write x y z also in hex
- // console.log(`Clipped ${index}: ${x}, ${y}, ${z} (0x${x.toString(16)}, 0x${y.toString(16)}, 0x${z.toString(16)}) -> ${xRounded}, ${yRounded}, ${zRounded} (0x${xRounded.toString(16)}, 0x${yRounded.toString(16)}, 0x${zRounded.toString(16)}) -> ${xInt}, ${yInt}, ${zInt} (0x${xInt.toString(16)}, 0x${yInt.toString(16)}, 0x${zInt.toString(16)})`);
- // }
- }
- const i9 = index * 9;
- const base = 16 + i9;
- this.view.setUint8(base, xInt & 0xff);
- this.view.setUint8(base + 1, (xInt >> 8) & 0xff);
- this.view.setUint8(base + 2, (xInt >> 16) & 0xff);
- this.view.setUint8(base + 3, yInt & 0xff);
- this.view.setUint8(base + 4, (yInt >> 8) & 0xff);
- this.view.setUint8(base + 5, (yInt >> 16) & 0xff);
- this.view.setUint8(base + 6, zInt & 0xff);
- this.view.setUint8(base + 7, (zInt >> 8) & 0xff);
- this.view.setUint8(base + 8, (zInt >> 16) & 0xff);
- }
-
- setAlpha(index: number, alpha: number) {
- const base = 16 + this.numSplats * 9 + index;
- this.view.setUint8(
- base,
- Math.max(0, Math.min(255, Math.round(alpha * 255))),
- );
- }
-
- static scaleRgb(r: number) {
- const v = ((r - 0.5) / (SH_C0 / 0.15) + 0.5) * 255;
- return Math.max(0, Math.min(255, Math.round(v)));
- }
-
- setRgb(index: number, r: number, g: number, b: number) {
- const base = 16 + this.numSplats * 10 + index * 3;
- this.view.setUint8(base, SpzWriter.scaleRgb(r));
- this.view.setUint8(base + 1, SpzWriter.scaleRgb(g));
- this.view.setUint8(base + 2, SpzWriter.scaleRgb(b));
- }
-
- setScale(index: number, scaleX: number, scaleY: number, scaleZ: number) {
- const base = 16 + this.numSplats * 13 + index * 3;
- this.view.setUint8(
- base,
- Math.max(0, Math.min(255, Math.round((Math.log(scaleX) + 10) * 16))),
- );
- this.view.setUint8(
- base + 1,
- Math.max(0, Math.min(255, Math.round((Math.log(scaleY) + 10) * 16))),
- );
- this.view.setUint8(
- base + 2,
- Math.max(0, Math.min(255, Math.round((Math.log(scaleZ) + 10) * 16))),
- );
- }
-
- setQuat(
- index: number,
- ...q: [number, number, number, number] // x, y, z, w
- ) {
- const base = 16 + this.numSplats * 16 + index * 4;
-
- const quat = normalize(q);
-
- // Find largest component
- let iLargest = 0;
- for (let i = 1; i < 4; ++i) {
- if (Math.abs(quat[i]) > Math.abs(quat[iLargest])) {
- iLargest = i;
- }
- }
-
- // Since -quat represents the same rotation as quat, transform the quaternion so the largest element
- // is positive. This avoids having to send its sign bit.
- const negate = quat[iLargest] < 0 ? 1 : 0;
-
- // Do compression using sign bit and 9-bit precision per element.
- let comp = iLargest;
- for (let i = 0; i < 4; ++i) {
- if (i !== iLargest) {
- const negbit = (quat[i] < 0 ? 1 : 0) ^ negate;
- const mag = Math.floor(
- ((1 << 9) - 1) * (Math.abs(quat[i]) / Math.SQRT1_2) + 0.5,
- );
- comp = (comp << 10) | (negbit << 9) | mag;
- }
- }
-
- this.view.setUint8(base, comp & 0xff);
- this.view.setUint8(base + 1, (comp >> 8) & 0xff);
- this.view.setUint8(base + 2, (comp >> 16) & 0xff);
- this.view.setUint8(base + 3, (comp >>> 24) & 0xff);
- }
-
- static quantizeSh(sh: number, bits: number) {
- const value = Math.round(sh * 128) + 128;
- const bucketSize = 1 << (8 - bits);
- const quantized =
- Math.floor((value + bucketSize / 2) / bucketSize) * bucketSize;
- return Math.max(0, Math.min(255, quantized));
- }
-
- setSh(
- index: number,
- sh1: Float32Array,
- sh2?: Float32Array,
- sh3?: Float32Array,
- ) {
- const shVecs = SH_DEGREE_TO_VECS[this.shDegree] || 0;
- const base1 = 16 + this.numSplats * 20 + index * shVecs * 3;
- for (let j = 0; j < 9; ++j) {
- this.view.setUint8(base1 + j, SpzWriter.quantizeSh(sh1[j], 5));
- }
- if (sh2) {
- const base2 = base1 + 9;
- for (let j = 0; j < 15; ++j) {
- this.view.setUint8(base2 + j, SpzWriter.quantizeSh(sh2[j], 4));
- }
- if (sh3) {
- const base3 = base2 + 15;
- for (let j = 0; j < 21; ++j) {
- this.view.setUint8(base3 + j, SpzWriter.quantizeSh(sh3[j], 4));
- }
- }
- }
- }
-
- async finalize(): Promise {
- const input = new Uint8Array(this.buffer);
- const stream = new ReadableStream({
- async start(controller) {
- controller.enqueue(input);
- controller.close();
- },
- });
- const compressed = stream.pipeThrough(new CompressionStream("gzip"));
- const response = new Response(compressed);
- const buffer = await response.arrayBuffer();
- console.log(
- "Compressed",
- input.length,
- "bytes to",
- buffer.byteLength,
- "bytes",
- );
- return new Uint8Array(buffer);
- }
-}
+import { decode_to_gsplatarray, packedsplats_to_gsplatarray } from "spark-rs";
export async function transcodeSpz(input: TranscodeSpzInput) {
- const splats = new SplatData();
+ const splatArrays = [];
const {
inputs,
clipXyz,
@@ -523,45 +18,9 @@ export async function transcodeSpz(input: TranscodeSpzInput) {
} = input;
for (const input of inputs) {
const scale = input.transform?.scale ?? 1;
- const quaternion = new THREE.Quaternion().fromArray(
- input.transform?.quaternion ?? [0, 0, 0, 1],
- );
- const translate = new THREE.Vector3().fromArray(
- input.transform?.translate ?? [0, 0, 0],
- );
- const clip = clipXyz
- ? new THREE.Box3(
- new THREE.Vector3().fromArray(clipXyz.min),
- new THREE.Vector3().fromArray(clipXyz.max),
- )
- : undefined;
-
- function transformPos(pos: THREE.Vector3) {
- pos.multiplyScalar(scale);
- pos.applyQuaternion(quaternion);
- pos.add(translate);
- return pos;
- }
-
- function transformScales(scales: THREE.Vector3) {
- scales.multiplyScalar(scale);
- return scales;
- }
-
- function transformQuaternion(quat: THREE.Quaternion) {
- quat.premultiply(quaternion);
- return quat;
- }
-
- function withinClip(p: THREE.Vector3) {
- return !clip || clip.containsPoint(p);
- }
-
- function withinOpacity(opacity: number) {
- return opacityThreshold !== undefined
- ? opacity >= opacityThreshold
- : true;
- }
+ const quaternion = input.transform?.quaternion ?? [0, 0, 0, 1];
+ const translate = input.transform?.translate ?? [0, 0, 0];
+ const clip = clipXyz ? [...clipXyz.min, ...clipXyz.max] : undefined;
let fileType = input.fileType;
if (!fileType) {
@@ -570,294 +29,52 @@ export async function transcodeSpz(input: TranscodeSpzInput) {
fileType = getSplatFileTypeFromPath(input.pathOrUrl);
}
}
- switch (fileType) {
- case SplatFileType.PLY: {
- const ply = new PlyReader({ fileBytes: input.fileBytes });
- await ply.parseHeader();
- let lastIndex: number | null = null;
- ply.parseSplats(
- (
- index,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- ) => {
- const center = transformPos(new THREE.Vector3(x, y, z));
- if (withinClip(center) && withinOpacity(opacity)) {
- lastIndex = splats.pushSplat();
- splats.setCenter(lastIndex, center.x, center.y, center.z);
- const scales = transformScales(
- new THREE.Vector3(scaleX, scaleY, scaleZ),
- );
- splats.setScale(lastIndex, scales.x, scales.y, scales.z);
- const quaternion = transformQuaternion(
- new THREE.Quaternion(quatX, quatY, quatZ, quatW),
- );
- splats.setQuaternion(
- lastIndex,
- quaternion.x,
- quaternion.y,
- quaternion.z,
- quaternion.w,
- );
- splats.setOpacity(lastIndex, opacity);
- splats.setColor(lastIndex, r, g, b);
- } else {
- lastIndex = null;
- }
- },
- (index, sh1, sh2, sh3) => {
- if (sh1 && lastIndex !== null) {
- splats.setSh1(lastIndex, sh1);
- }
- if (sh2 && lastIndex !== null) {
- splats.setSh2(lastIndex, sh2);
- }
- if (sh3 && lastIndex !== null) {
- splats.setSh3(lastIndex, sh3);
- }
- },
- );
- break;
- }
- case SplatFileType.SPZ: {
- const spz = new SpzReader({ fileBytes: input.fileBytes });
- await spz.parseHeader();
- const mapping = new Int32Array(spz.numSplats);
- mapping.fill(-1);
- const centers = new Float32Array(spz.numSplats * 3);
- const center = new THREE.Vector3();
- spz.parseSplats(
- (index, x, y, z) => {
- const center = transformPos(new THREE.Vector3(x, y, z));
- centers[index * 3] = center.x;
- centers[index * 3 + 1] = center.y;
- centers[index * 3 + 2] = center.z;
- },
- (index, alpha) => {
- center.fromArray(centers, index * 3);
- if (withinClip(center) && withinOpacity(alpha)) {
- mapping[index] = splats.pushSplat();
- splats.setCenter(mapping[index], center.x, center.y, center.z);
- splats.setOpacity(mapping[index], alpha);
- }
- },
- (index, r, g, b) => {
- if (mapping[index] >= 0) {
- splats.setColor(mapping[index], r, g, b);
- }
- },
- (index, scaleX, scaleY, scaleZ) => {
- if (mapping[index] >= 0) {
- const scales = transformScales(
- new THREE.Vector3(scaleX, scaleY, scaleZ),
- );
- splats.setScale(mapping[index], scales.x, scales.y, scales.z);
- }
- },
- (index, quatX, quatY, quatZ, quatW) => {
- if (mapping[index] >= 0) {
- const quaternion = transformQuaternion(
- new THREE.Quaternion(quatX, quatY, quatZ, quatW),
- );
- splats.setQuaternion(
- mapping[index],
- quaternion.x,
- quaternion.y,
- quaternion.z,
- quaternion.w,
- );
- }
- },
- (index, sh1, sh2, sh3) => {
- if (mapping[index] >= 0) {
- splats.setSh1(mapping[index], sh1);
- if (sh2) {
- splats.setSh2(mapping[index], sh2);
- }
- if (sh3) {
- splats.setSh3(mapping[index], sh3);
- }
- }
- },
- );
- break;
- }
- case SplatFileType.SPLAT:
- decodeAntiSplat(
- input.fileBytes,
- (numSplats) => {},
- (
- index,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- ) => {
- const center = transformPos(new THREE.Vector3(x, y, z));
- if (withinClip(center) && withinOpacity(opacity)) {
- const index = splats.pushSplat();
- splats.setCenter(index, center.x, center.y, center.z);
- const scales = transformScales(
- new THREE.Vector3(scaleX, scaleY, scaleZ),
- );
- splats.setScale(index, scales.x, scales.y, scales.z);
- const quaternion = transformQuaternion(
- new THREE.Quaternion(quatX, quatY, quatZ, quatW),
- );
- splats.setQuaternion(
- index,
- quaternion.x,
- quaternion.y,
- quaternion.z,
- quaternion.w,
- );
- splats.setOpacity(index, opacity);
- splats.setColor(index, r, g, b);
- }
- },
- );
- break;
- case SplatFileType.KSPLAT: {
- let lastIndex: number | null = null;
- decodeKsplat(
- input.fileBytes,
- (numSplats) => {},
- (
- index,
- x,
- y,
- z,
- scaleX,
- scaleY,
- scaleZ,
- quatX,
- quatY,
- quatZ,
- quatW,
- opacity,
- r,
- g,
- b,
- ) => {
- const center = transformPos(new THREE.Vector3(x, y, z));
- if (withinClip(center) && withinOpacity(opacity)) {
- lastIndex = splats.pushSplat();
- splats.setCenter(lastIndex, center.x, center.y, center.z);
- const scales = transformScales(
- new THREE.Vector3(scaleX, scaleY, scaleZ),
- );
- splats.setScale(lastIndex, scales.x, scales.y, scales.z);
- const quaternion = transformQuaternion(
- new THREE.Quaternion(quatX, quatY, quatZ, quatW),
- );
- splats.setQuaternion(
- lastIndex,
- quaternion.x,
- quaternion.y,
- quaternion.z,
- quaternion.w,
- );
- splats.setOpacity(lastIndex, opacity);
- splats.setColor(lastIndex, r, g, b);
- } else {
- lastIndex = null;
- }
- },
- (index, sh1, sh2, sh3) => {
- if (lastIndex !== null) {
- splats.setSh1(lastIndex, sh1);
- if (sh2) {
- splats.setSh2(lastIndex, sh2);
- }
- if (sh3) {
- splats.setSh3(lastIndex, sh3);
- }
- }
- },
- );
- break;
- }
- default:
- throw new Error(`transcodeSpz not implemented for ${fileType}`);
+ const decoder = decode_to_gsplatarray(fileType, input.pathOrUrl);
+ const fileBytes = input.fileBytes;
+ const CHUNK_SIZE = 1048576; // 1 MB
+ for (let i = 0; i < fileBytes.length; i += CHUNK_SIZE) {
+ decoder.push(
+ fileBytes.subarray(i, Math.min(i + CHUNK_SIZE, fileBytes.length)),
+ );
}
- }
+ const decoded = decoder.finish();
- const shDegree = Math.min(
- maxSh ?? 3,
- splats.sh3 ? 3 : splats.sh2 ? 2 : splats.sh1 ? 1 : 0,
- );
- const spz = new SpzWriter({
- numSplats: splats.numSplats,
- shDegree,
- fractionalBits,
- flagAntiAlias: true,
- });
+ decoded.transform({
+ translation: translate,
+ rotation: quaternion,
+ scale,
+ clip,
+ opacityThreshold: opacityThreshold ?? 0,
+ });
- for (let i = 0; i < splats.numSplats; ++i) {
- const i3 = i * 3;
- const i4 = i * 4;
- spz.setCenter(
- i,
- splats.centers[i3],
- splats.centers[i3 + 1],
- splats.centers[i3 + 2],
- );
- spz.setScale(
- i,
- splats.scales[i3],
- splats.scales[i3 + 1],
- splats.scales[i3 + 2],
- );
- spz.setQuat(
- i,
- splats.quaternions[i4],
- splats.quaternions[i4 + 1],
- splats.quaternions[i4 + 2],
- splats.quaternions[i4 + 3],
- );
- spz.setAlpha(i, splats.opacities[i]);
- spz.setRgb(
- i,
- splats.colors[i3],
- splats.colors[i3 + 1],
- splats.colors[i3 + 2],
- );
- if (splats.sh1 && shDegree >= 1) {
- spz.setSh(
- i,
- splats.sh1.slice(i * 9, (i + 1) * 9),
- shDegree >= 2 && splats.sh2
- ? splats.sh2.slice(i * 15, (i + 1) * 15)
- : undefined,
- shDegree >= 3 && splats.sh3
- ? splats.sh3.slice(i * 21, (i + 1) * 21)
- : undefined,
- );
- }
+ splatArrays.push(decoded);
}
- const spzBytes = await spz.finalize();
- return { fileBytes: spzBytes, clippedCount: spz.clippedCount };
+ // Combine decoded splat arrays
+ const finalSplats = splatArrays[0];
+ for (let i = 1; i < splatArrays.length; i++) {
+ finalSplats.concat(splatArrays[i]);
+ }
+
+ const spzBytes = finalSplats.encode_to_spz(maxSh ?? 3, fractionalBits);
+
+ return { fileBytes: spzBytes, clippedCount: 0 };
+}
+
+export function writeSpz(
+ packedSplats: PackedSplats,
+ maxSh?: number,
+ fractionalBits?: number,
+) {
+ if (!packedSplats.packedArray) {
+ throw new Error("");
+ }
+ const gsplats = packedsplats_to_gsplatarray(
+ packedSplats.numSplats,
+ packedSplats.packedArray,
+ packedSplats.extra,
+ packedSplats.splatEncoding,
+ );
+ const spzBytes = gsplats.encode_to_spz(maxSh ?? 3, fractionalBits ?? 12);
+ return { fileBytes: spzBytes };
}
diff --git a/src/worker.ts b/src/worker.ts
index 1801a758..a6825d4e 100644
--- a/src/worker.ts
+++ b/src/worker.ts
@@ -96,35 +96,6 @@ function sortSplats32({
return { activeSplats, readback, ordering };
}
-async function fetchRange({
- url,
- requestHeader,
- withCredentials,
- offset,
- bytes,
-}: {
- url: string;
- requestHeader?: Record;
- withCredentials?: string;
- offset?: number;
- bytes?: number;
-}): Promise {
- const request = new Request(url, {
- headers: requestHeader ? new Headers(requestHeader) : undefined,
- credentials: withCredentials ? "include" : "same-origin",
- });
- if (offset !== undefined && bytes !== undefined) {
- request.headers.set("Range", `bytes=${offset}-${offset + bytes - 1}`);
- }
- const response = await fetch(request);
- if (!response.ok || !response.body) {
- throw new Error(
- `Failed to fetch "${url}": ${response.status} ${response.statusText}`,
- );
- }
- return new Uint8Array(await response.arrayBuffer());
-}
-
async function decodeBytesUrl({
decoder,
fileBytes,