diff --git a/examples/editor/index.html b/examples/editor/index.html index 36526634..a3d19243 100644 --- a/examples/editor/index.html +++ b/examples/editor/index.html @@ -69,7 +69,7 @@ import * as THREE from "three"; import { OrbitControls } from "three/addons/controls/OrbitControls.js"; import { GUI } from "lil-gui"; - import { constructGrid, SparkControls, SparkRenderer, SplatMesh, textSplats, dyno, transcodeSpz, isMobile, isPcSogs, LN_SCALE_MIN, LN_SCALE_MAX } from "@sparkjsdev/spark"; + import { constructGrid, SparkControls, SparkRenderer, SplatMesh, textSplats, dyno, transcodeSpz, isMobile, LN_SCALE_MIN, LN_SCALE_MAX } from "@sparkjsdev/spark"; import { getAssetFileURL } from "../js/get-asset-url.js"; const scene = new THREE.Scene(); diff --git a/examples/splat-painter/index.html b/examples/splat-painter/index.html index 133135cb..0e56cf5c 100644 --- a/examples/splat-painter/index.html +++ b/examples/splat-painter/index.html @@ -62,7 +62,7 @@ RgbaArray, readRgbaArray, SparkControls, - SpzWriter, + writeSpz, unpackSplat, PackedSplats, } from "@sparkjsdev/spark"; @@ -405,7 +405,7 @@ alert("No splat mesh loaded for export."); return; } - + try { console.log("Starting SPZ export with painted changes..."); @@ -416,7 +416,7 @@ }); currentSplatMesh.updateGenerator(); } - + const rgba = new RgbaArray(); rgba.render({ renderer, @@ -470,7 +470,7 @@ unpacked.color.g = rgbaBytes[rgbaOffset + 1] / 255; unpacked.color.b = rgbaBytes[rgbaOffset + 2] / 255; unpacked.opacity = opacity; - + // Push to new PackedSplats newPackedSplats.pushSplat( unpacked.center, @@ -482,39 +482,13 @@ processedCount++; } - + console.log(`Processed ${processedCount} splats`); // Now export the PackedSplats to SPZ - console.log("Creating SPZ writer..."); - const maxSh = ioOptions.maxSh; - const spzWriter = new SpzWriter({ - numSplats: nonZeroCount, - shDegree: maxSh, - fractionalBits: ioOptions.fractionalBits, - flagAntiAlias: true, - }); - console.log("Writing splats to SPZ..."); - // Iterate through the new packed array - for (let i = 0; i < nonZeroCount; i++) { - const unpacked = unpackSplat( - newPackedSplats.packedArray, - i, - newPackedSplats.splatEncoding - ); - - spzWriter.setCenter(i, unpacked.center.x, unpacked.center.y, unpacked.center.z); - spzWriter.setScale(i, unpacked.scales.x, unpacked.scales.y, unpacked.scales.z); - spzWriter.setQuat(i, unpacked.quaternion.x, unpacked.quaternion.y, unpacked.quaternion.z, unpacked.quaternion.w); - spzWriter.setAlpha(i, unpacked.opacity); - spzWriter.setRgb(i, unpacked.color.r, unpacked.color.g, unpacked.color.b); - } - const spzBytes = await spzWriter.finalize(); - if (spzWriter.clippedCount > 0) { - console.log(`Clipped ${spzWriter.clippedCount} splats. Consider decreasing fractional-bits from ${ioOptions.fractionalBits} to reduce clipping.`); - } - + const { fileBytes: spzBytes } = writeSpz(newPackedSplats, ioOptions.maxSh, ioOptions.fractionalBits); + console.log("Creating download..."); const blob = new Blob([spzBytes], { type: "application/octet-stream" }); const url = URL.createObjectURL(blob); @@ -531,7 +505,7 @@ } }, }; - + const ioFolder = gui.addFolder("I/O"); ioFolder.add(ioOptions, "loadFile").name("Load Splats (SPZ/PLY)"); ioFolder.add(ioOptions, "saveToSpz").name("Save Splats (SPZ)"); diff --git a/rust/Cargo.lock b/rust/Cargo.lock index e85117aa..625c7b67 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1073,6 +1073,7 @@ dependencies = [ "itertools", "js-sys", "ordered-float", + "serde", "serde-wasm-bindgen", "serde_json", "smallvec", diff --git a/rust/spark-rs/Cargo.toml b/rust/spark-rs/Cargo.toml index 7c4f3e13..bab51740 100644 --- a/rust/spark-rs/Cargo.toml +++ b/rust/spark-rs/Cargo.toml @@ -26,5 +26,6 @@ web-sys = { workspace = true, features = ["Window", "Performance"] } spark-lib = { path = "../spark-lib" } serde-wasm-bindgen.workspace = true serde_json.workspace = true +serde.workspace = true itertools.workspace = true console_error_panic_hook = "0.1" diff --git a/rust/spark-rs/src/lib.rs b/rust/spark-rs/src/lib.rs index 4aa1b10e..68d11d53 100644 --- a/rust/spark-rs/src/lib.rs +++ b/rust/spark-rs/src/lib.rs @@ -2,6 +2,8 @@ use std::cell::RefCell; use js_sys::{Array, Float32Array, Object, Reflect, Uint8Array, Uint16Array, Uint32Array}; use spark_lib::decoder::{ChunkReceiver, MultiDecoder, SplatEncoding, SplatFileType, SplatGetter}; +use spark_lib::spz::SpzEncoder; +use spark_lib::gsplat::{GsplatSH1,GsplatSH2,GsplatSH3}; use spark_lib::gsplat::GsplatArray as GsplatArrayInner; use spark_lib::csplat::CsplatArray as CsplatArrayInner; use spark_lib::tsplat::TsplatArray; @@ -16,6 +18,9 @@ use raycast::{raycast_packed_ellipsoids, raycast_ext_ellipsoids}; mod sort; use sort::{sort_internal, SortBuffers, sort32_internal, Sort32Buffers}; +mod transform; +use transform::{transform_gsplatarray, TransformOptions}; + mod decoder; mod packed_splats; mod ext_splats; @@ -179,6 +184,7 @@ impl GsplatArray { } } + #[wasm_bindgen] impl GsplatArray { pub fn len(&self) -> usize { @@ -257,6 +263,32 @@ impl GsplatArray { pub fn inject_rgba8(&mut self, rgba: Uint8Array) { self.inner.inject_rgba8(&rgba.to_vec()); } + + pub fn transform(&mut self, transform: JsValue) -> Result<(), JsValue> { + let transform_options: TransformOptions = serde_wasm_bindgen::from_value(transform)?; + transform_gsplatarray(&mut self.inner, transform_options); + Ok(()) + } + + pub fn concat(&mut self, other: &mut GsplatArray) -> Result<(), JsValue> { + for i in 0..other.inner.len() { + let sh1 = if other.maxShDegree >= 1 { other.inner.sh1[i].clone() } else { GsplatSH1::default() }; + let sh2 = if other.maxShDegree >= 2 { other.inner.sh2[i].clone() } else { GsplatSH2::default() }; + let sh3 = if other.maxShDegree >= 3 { other.inner.sh3[i].clone() } else { GsplatSH3::default() }; + self.inner.push_splat(other.inner.get(i).clone(), Some(sh1), Some(sh2), Some(sh3)); + } + Ok(()) + } + + pub fn encode_to_spz(mut self, max_sh: u32, fractional_bits: u8) -> Result { + self.inner.clamp_sh_degree(max_sh as usize); + self.maxShDegree = self.inner.max_sh_degree; + let encoded = match SpzEncoder::new(self.inner).with_max_sh(max_sh as usize).with_fractional_bits(fractional_bits).encode() { + Err(err) => { return Err(JsValue::from(err.to_string())); }, + Ok(encoded) => encoded + }; + Ok(Uint8Array::from(encoded.as_slice())) + } } #[wasm_bindgen] diff --git a/rust/spark-rs/src/transform.rs b/rust/spark-rs/src/transform.rs new file mode 100644 index 00000000..773c649a --- /dev/null +++ b/rust/spark-rs/src/transform.rs @@ -0,0 +1,73 @@ +use glam::{Vec3A, Quat}; + +use spark_lib::gsplat::GsplatArray; +use spark_lib::tsplat::TsplatArray; +use spark_lib::tsplat::Tsplat; +use serde::{Deserialize, Serialize}; + +use spark_lib::decoder::SplatReceiver; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TransformOptions { + pub translation: [f32; 3], + pub rotation: [f32; 4], + pub scale: f32, + pub clip: Option<[f32; 6]>, + #[serde(rename = "opacityThreshold")] + pub opacity_threshold: f32, +} + +pub fn transform_gsplatarray(gsplats: &mut GsplatArray, transform_options: TransformOptions) { + let translation = Vec3A::from_array(transform_options.translation); + let quaternion = Quat::from_array(transform_options.rotation); + let scale = Vec3A::splat(transform_options.scale); + + let clip = transform_options.clip.map(|clip| (Vec3A::from_slice(&clip[..3]), Vec3A::from_slice(&clip[3..]))); + + let mut out_index = 0; + for splat_index in 0..gsplats.splats.len() { + let in_splat = gsplats.get(splat_index); + + let mut center = in_splat.center(); + // Transform center + center = quaternion * (center * scale) + translation; + + // Check clip box + let clipped = match clip { + Some((min, max)) => (center.cmplt(min)).any() || (center.cmpgt(max)).any(), + None => false + }; + if clipped { + continue; + } + + // Check opacity threshold + let opacity = in_splat.opacity(); + if opacity < transform_options.opacity_threshold { + continue; + } + + let mut scales = in_splat.scales(); + let mut quat = in_splat.quaternion(); + let rgb = in_splat.rgb(); + + gsplats.set_center(out_index, 1, ¢er.to_array()); + + scales *= scale; + gsplats.set_scale(out_index, 1, &scales.to_array()); + + quat *= quaternion; + gsplats.set_quat(out_index, 1, &quat.to_array()); + + gsplats.set_rgb(out_index, 1, &rgb.to_array()); + gsplats.set_opacity(out_index, 1, &[opacity]); + + gsplats.set_sh1(out_index, 1, gsplats.get_sh1(splat_index).as_slice()); + gsplats.set_sh2(out_index, 1, gsplats.get_sh2(splat_index).as_slice()); + gsplats.set_sh3(out_index, 1, gsplats.get_sh3(splat_index).as_slice()); + + out_index += 1; + } + + gsplats.truncate(out_index); +} \ No newline at end of file diff --git a/src/SplatLoader.ts b/src/SplatLoader.ts index b5866099..c2a509e2 100644 --- a/src/SplatLoader.ts +++ b/src/SplatLoader.ts @@ -5,7 +5,6 @@ import { PackedSplats, type PackedSplatsOptions } from "./PackedSplats"; import { SplatMesh } from "./SplatMesh"; import { workerPool } from "./SplatWorker"; import { type SplatEncoding, SplatFileType } from "./defines"; -import { PlyReader } from "./ply"; import { decompressPartialGzip, getTextureSize } from "./utils"; // SplatLoader implements the THREE.Loader interface and supports loading a variety @@ -340,65 +339,6 @@ export class SplatLoader extends Loader { } } -async function fetchWithProgress( - request: Request, - onProgress?: (event: ProgressEvent) => void, -) { - const response = await fetch(request); - if (!response.ok) { - throw new Error( - `${response.status} "${response.statusText}" fetching URL: ${request.url}`, - ); - } - if (!response.body) { - throw new Error(`Response body is null for URL: ${request.url}`); - } - - const reader = response.body.getReader(); - let loaded = 0; - const chunks: Uint8Array[] = []; - try { - const contentLength = Number.parseInt( - response.headers.get("Content-Length") || "0", - ); - const total = Number.isNaN(contentLength) ? 0 : contentLength; - - while (true) { - const { done, value } = await reader.read(); - if (done) { - break; - } - chunks.push(value); - loaded += value.length; - - if (onProgress) { - onProgress( - new ProgressEvent("progress", { - lengthComputable: total !== 0, - loaded, - total, - }), - ); - } - } - } catch (err) { - try { - const reason = err instanceof Error ? err.message : "Unknown error"; - await reader.cancel(reason); - } catch {} - throw err; - } - - // Combine chunks into a single buffer - const bytes = new Uint8Array(loaded); - let offset = 0; - for (const chunk of chunks) { - bytes.set(chunk, offset); - offset += chunk.length; - } - return bytes.buffer; -} - export function getSplatFileType( fileBytes: Uint8Array, ): SplatFileType | undefined { diff --git a/src/antisplat.ts b/src/antisplat.ts deleted file mode 100644 index 8b52864d..00000000 --- a/src/antisplat.ts +++ /dev/null @@ -1,125 +0,0 @@ -import type { SplatEncoding } from "./defines"; -import { computeMaxSplats, setPackedSplat } from "./utils"; - -export function decodeAntiSplat( - fileBytes: Uint8Array, - initNumSplats: (numSplats: number) => void, - splatCallback: ( - index: number, - x: number, - y: number, - z: number, - scaleX: number, - scaleY: number, - scaleZ: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - opacity: number, - r: number, - g: number, - b: number, - ) => void, -) { - const numSplats = Math.floor(fileBytes.length / 32); // 32 bytes per splat - if (numSplats * 32 !== fileBytes.length) { - throw new Error("Invalid .splat file size"); - } - initNumSplats(numSplats); - - const f32 = new Float32Array(fileBytes.buffer); - for (let i = 0; i < numSplats; ++i) { - const i32 = i * 32; - const i8 = i * 8; - const x = f32[i8 + 0]; - const y = f32[i8 + 1]; - const z = f32[i8 + 2]; - const scaleX = f32[i8 + 3]; - const scaleY = f32[i8 + 4]; - const scaleZ = f32[i8 + 5]; - const r = fileBytes[i32 + 24] / 255; - const g = fileBytes[i32 + 25] / 255; - const b = fileBytes[i32 + 26] / 255; - const opacity = fileBytes[i32 + 27] / 255; - const quatW = (fileBytes[i32 + 28] - 128) / 128; - const quatX = (fileBytes[i32 + 29] - 128) / 128; - const quatY = (fileBytes[i32 + 30] - 128) / 128; - const quatZ = (fileBytes[i32 + 31] - 128) / 128; - splatCallback( - i, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ); - } -} - -export function unpackAntiSplat( - fileBytes: Uint8Array, - splatEncoding: SplatEncoding, -): { - packedArray: Uint32Array; - numSplats: number; -} { - let numSplats = 0; - let maxSplats = 0; - let packedArray = new Uint32Array(0); - decodeAntiSplat( - fileBytes, - (cbNumSplats) => { - numSplats = cbNumSplats; - maxSplats = computeMaxSplats(numSplats); - packedArray = new Uint32Array(maxSplats * 4); - }, - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - setPackedSplat( - packedArray, - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - splatEncoding, - ); - }, - ); - return { packedArray, numSplats }; -} diff --git a/src/index.ts b/src/index.ts index 4a68545e..2ce39cc9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -11,10 +11,8 @@ export { RgbaArray, readRgbaArray } from "./RgbaArray"; export { SplatLoader, getSplatFileType, - isPcSogs, } from "./SplatLoader"; -export { PlyReader } from "./ply"; -export { SpzReader, SpzWriter, transcodeSpz } from "./spz"; +export { transcodeSpz, writeSpz } from "./spz"; export { PackedSplats, type PackedSplatsOptions } from "./PackedSplats"; export { ExtSplats, type ExtSplatsOptions } from "./ExtSplats"; diff --git a/src/ksplat.ts b/src/ksplat.ts deleted file mode 100644 index 3c46e58e..00000000 --- a/src/ksplat.ts +++ /dev/null @@ -1,636 +0,0 @@ -import type { SplatEncoding } from "./defines"; -import { - computeMaxSplats, - encodeSh1Rgb, - encodeSh2Rgb, - encodeSh3Rgb, - fromHalf, - setPackedSplat, -} from "./utils"; - -type KsplatCompression = { - bytesPerCenter: number; - bytesPerScale: number; - bytesPerRotation: number; - bytesPerColor: number; - bytesPerSphericalHarmonicsComponent: number; - scaleOffsetBytes: number; - rotationOffsetBytes: number; - colorOffsetBytes: number; - sphericalHarmonicsOffsetBytes: number; - scaleRange: number; -}; - -const KSPLAT_COMPRESSION: Record = { - 0: { - bytesPerCenter: 12, - bytesPerScale: 12, - bytesPerRotation: 16, - bytesPerColor: 4, - bytesPerSphericalHarmonicsComponent: 4, - scaleOffsetBytes: 12, - rotationOffsetBytes: 24, - colorOffsetBytes: 40, - sphericalHarmonicsOffsetBytes: 44, - scaleRange: 1, - }, - 1: { - bytesPerCenter: 6, - bytesPerScale: 6, - bytesPerRotation: 8, - bytesPerColor: 4, - bytesPerSphericalHarmonicsComponent: 2, - scaleOffsetBytes: 6, - rotationOffsetBytes: 12, - colorOffsetBytes: 20, - sphericalHarmonicsOffsetBytes: 24, - scaleRange: 32767, - }, - 2: { - bytesPerCenter: 6, - bytesPerScale: 6, - bytesPerRotation: 8, - bytesPerColor: 4, - bytesPerSphericalHarmonicsComponent: 1, - scaleOffsetBytes: 6, - rotationOffsetBytes: 12, - colorOffsetBytes: 20, - sphericalHarmonicsOffsetBytes: 24, - scaleRange: 32767, - }, -}; - -const KSPLAT_SH_DEGREE_TO_COMPONENTS: Record = { - 0: 0, - 1: 9, - 2: 24, - 3: 45, -}; - -export function decodeKsplat( - fileBytes: Uint8Array, - initNumSplats: (numSplats: number) => void, - splatCallback: ( - index: number, - x: number, - y: number, - z: number, - scaleX: number, - scaleY: number, - scaleZ: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - opacity: number, - r: number, - g: number, - b: number, - ) => void, - shCallback?: ( - index: number, - sh1: Float32Array, - sh2?: Float32Array, - sh3?: Float32Array, - ) => void, -) { - const HEADER_BYTES = 4096; - const SECTION_BYTES = 1024; - - let headerOffset = 0; - const header = new DataView(fileBytes.buffer, headerOffset, HEADER_BYTES); - headerOffset += HEADER_BYTES; - - const versionMajor = header.getUint8(0); - const versionMinor = header.getUint8(1); - if (versionMajor !== 0 || versionMinor < 1) { - throw new Error( - `Unsupported .ksplat version: ${versionMajor}.${versionMinor}`, - ); - } - const maxSectionCount = header.getUint32(4, true); - // const sectionCount = header.getUint32(8, true); - // const maxSplatCount = header.getUint32(12, true); - const splatCount = header.getUint32(16, true); - const compressionLevel = header.getUint16(20, true); - if (compressionLevel < 0 || compressionLevel > 2) { - throw new Error(`Invalid .ksplat compression level: ${compressionLevel}`); - } - // const sceneCenterX = header.getFloat32(24, true); - // const sceneCenterY = header.getFloat32(28, true); - // const sceneCenterZ = header.getFloat32(32, true); - const minSphericalHarmonicsCoeff = header.getFloat32(36, true) || -1.5; - const maxSphericalHarmonicsCoeff = header.getFloat32(40, true) || 1.5; - - const numSplats = splatCount; - initNumSplats(numSplats); - const maxSplats = computeMaxSplats(numSplats); - const packedArray = new Uint32Array(maxSplats * 4); - const extra: Record = {}; - - let sectionBase = HEADER_BYTES + maxSectionCount * SECTION_BYTES; - - for (let section = 0; section < maxSectionCount; ++section) { - const section = new DataView(fileBytes.buffer, headerOffset, SECTION_BYTES); - headerOffset += SECTION_BYTES; - - const sectionSplatCount = section.getUint32(0, true); - const sectionMaxSplatCount = section.getUint32(4, true); - const bucketSize = section.getUint32(8, true); - const bucketCount = section.getUint32(12, true); - const bucketBlockSize = section.getFloat32(16, true); - const bucketStorageSizeBytes = section.getUint16(20, true); - const compressionScaleRange = - (section.getUint32(24, true) || - KSPLAT_COMPRESSION[compressionLevel]?.scaleRange) ?? - 1; - const fullBucketCount = section.getUint32(32, true); - const fullBucketSplats = fullBucketCount * bucketSize; - const partiallyFilledBucketCount = section.getUint32(36, true); - const bucketsMetaDataSizeBytes = partiallyFilledBucketCount * 4; - const bucketsStorageSizeBytes = - bucketStorageSizeBytes * bucketCount + bucketsMetaDataSizeBytes; - const sphericalHarmonicsDegree = section.getUint16(40, true); - const shComponents = - KSPLAT_SH_DEGREE_TO_COMPONENTS[sphericalHarmonicsDegree]; - - const { - bytesPerCenter, - bytesPerScale, - bytesPerRotation, - bytesPerColor, - bytesPerSphericalHarmonicsComponent, - scaleOffsetBytes, - rotationOffsetBytes, - colorOffsetBytes, - sphericalHarmonicsOffsetBytes, - } = KSPLAT_COMPRESSION[compressionLevel]; - const bytesPerSplat = - bytesPerCenter + - bytesPerScale + - bytesPerRotation + - bytesPerColor + - shComponents * bytesPerSphericalHarmonicsComponent; - const splatDataStorageSizeBytes = bytesPerSplat * sectionMaxSplatCount; - const storageSizeBytes = - splatDataStorageSizeBytes + bucketsStorageSizeBytes; - - const sh1Index = [0, 3, 6, 1, 4, 7, 2, 5, 8]; - const sh2Index = [ - 9, 14, 19, 10, 15, 20, 11, 16, 21, 12, 17, 22, 13, 18, 23, - ]; - const sh3Index = [ - 24, 31, 38, 25, 32, 39, 26, 33, 40, 27, 34, 41, 28, 35, 42, 29, 36, 43, - 30, 37, 44, - ]; - const sh1 = - sphericalHarmonicsDegree >= 1 ? new Float32Array(3 * 3) : undefined; - const sh2 = - sphericalHarmonicsDegree >= 2 ? new Float32Array(5 * 3) : undefined; - const sh3 = - sphericalHarmonicsDegree >= 3 ? new Float32Array(7 * 3) : undefined; - - const compressionScaleFactor = bucketBlockSize / 2 / compressionScaleRange; - const bucketsBase = sectionBase + bucketsMetaDataSizeBytes; - const dataBase = sectionBase + bucketsStorageSizeBytes; - const data = new DataView( - fileBytes.buffer, - dataBase, - splatDataStorageSizeBytes, - ); - const bucketArray = new Float32Array( - fileBytes.buffer, - bucketsBase, - bucketCount * 3, - ); - const partiallyFilledBucketLengths = new Uint32Array( - fileBytes.buffer, - sectionBase, - partiallyFilledBucketCount, - ); - - function getSh(splatOffset: number, component: number) { - if (compressionLevel === 0) { - return data.getFloat32( - splatOffset + sphericalHarmonicsOffsetBytes + component * 4, - true, - ); - } - if (compressionLevel === 1) { - return fromHalf( - data.getUint16( - splatOffset + sphericalHarmonicsOffsetBytes + component * 2, - true, - ), - ); - } - const t = - data.getUint8(splatOffset + sphericalHarmonicsOffsetBytes + component) / - 255; - return ( - minSphericalHarmonicsCoeff + - t * (maxSphericalHarmonicsCoeff - minSphericalHarmonicsCoeff) - ); - } - - let partialBucketIndex = fullBucketCount; - let partialBucketBase = fullBucketSplats; - - for (let i = 0; i < sectionSplatCount; ++i) { - const splatOffset = i * bytesPerSplat; - - let bucketIndex: number; - if (i < fullBucketSplats) { - bucketIndex = Math.floor(i / bucketSize); - } else { - const bucketLength = - partiallyFilledBucketLengths[partialBucketIndex - fullBucketCount]; - if (i >= partialBucketBase + bucketLength) { - partialBucketIndex += 1; - partialBucketBase += bucketLength; - } - bucketIndex = partialBucketIndex; - } - - const x = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 0, true) - : (data.getUint16(splatOffset + 0, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 0]; - const y = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 4, true) - : (data.getUint16(splatOffset + 2, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 1]; - const z = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 8, true) - : (data.getUint16(splatOffset + 4, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 2]; - - const scaleX = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 0, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 0, true)); - const scaleY = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 4, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 2, true)); - const scaleZ = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 8, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 4, true)); - - const quatW = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 0, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 0, true), - ); - const quatX = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 4, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 2, true), - ); - const quatY = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 8, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 4, true), - ); - const quatZ = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 12, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 6, true), - ); - - const r = data.getUint8(splatOffset + colorOffsetBytes + 0) / 255; - const g = data.getUint8(splatOffset + colorOffsetBytes + 1) / 255; - const b = data.getUint8(splatOffset + colorOffsetBytes + 2) / 255; - const opacity = data.getUint8(splatOffset + colorOffsetBytes + 3) / 255; - - splatCallback( - i, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ); - - if (sphericalHarmonicsDegree >= 1 && sh1) { - for (const [i, key] of sh1Index.entries()) { - sh1[i] = getSh(splatOffset, key); - } - if (sh2) { - for (const [i, key] of sh2Index.entries()) { - sh2[i] = getSh(splatOffset, key); - } - } - if (sh3) { - for (const [i, key] of sh3Index.entries()) { - sh3[i] = getSh(splatOffset, key); - } - } - shCallback?.(i, sh1, sh2, sh3); - } - } - sectionBase += storageSizeBytes; - } -} - -export function unpackKsplat( - fileBytes: Uint8Array, - splatEncoding: SplatEncoding, -): { - packedArray: Uint32Array; - numSplats: number; - extra: Record; -} { - const HEADER_BYTES = 4096; - const SECTION_BYTES = 1024; - - let headerOffset = 0; - const header = new DataView(fileBytes.buffer, headerOffset, HEADER_BYTES); - headerOffset += HEADER_BYTES; - - const versionMajor = header.getUint8(0); - const versionMinor = header.getUint8(1); - if (versionMajor !== 0 || versionMinor < 1) { - throw new Error( - `Unsupported .ksplat version: ${versionMajor}.${versionMinor}`, - ); - } - const maxSectionCount = header.getUint32(4, true); - // const sectionCount = header.getUint32(8, true); - // const maxSplatCount = header.getUint32(12, true); - const splatCount = header.getUint32(16, true); - const compressionLevel = header.getUint16(20, true); - if (compressionLevel < 0 || compressionLevel > 2) { - throw new Error(`Invalid .ksplat compression level: ${compressionLevel}`); - } - // const sceneCenterX = header.getFloat32(24, true); - // const sceneCenterY = header.getFloat32(28, true); - // const sceneCenterZ = header.getFloat32(32, true); - const minSphericalHarmonicsCoeff = header.getFloat32(36, true) || -1.5; - const maxSphericalHarmonicsCoeff = header.getFloat32(40, true) || 1.5; - - const numSplats = splatCount; - const maxSplats = computeMaxSplats(numSplats); - const packedArray = new Uint32Array(maxSplats * 4); - const extra: Record = {}; - - let sectionBase = HEADER_BYTES + maxSectionCount * SECTION_BYTES; - - for (let section = 0; section < maxSectionCount; ++section) { - const section = new DataView(fileBytes.buffer, headerOffset, SECTION_BYTES); - headerOffset += SECTION_BYTES; - - const sectionSplatCount = section.getUint32(0, true); - const sectionMaxSplatCount = section.getUint32(4, true); - const bucketSize = section.getUint32(8, true); - const bucketCount = section.getUint32(12, true); - const bucketBlockSize = section.getFloat32(16, true); - const bucketStorageSizeBytes = section.getUint16(20, true); - const compressionScaleRange = - (section.getUint32(24, true) || - KSPLAT_COMPRESSION[compressionLevel]?.scaleRange) ?? - 1; - const fullBucketCount = section.getUint32(32, true); - const fullBucketSplats = fullBucketCount * bucketSize; - const partiallyFilledBucketCount = section.getUint32(36, true); - const bucketsMetaDataSizeBytes = partiallyFilledBucketCount * 4; - const bucketsStorageSizeBytes = - bucketStorageSizeBytes * bucketCount + bucketsMetaDataSizeBytes; - const sphericalHarmonicsDegree = section.getUint16(40, true); - const shComponents = - KSPLAT_SH_DEGREE_TO_COMPONENTS[sphericalHarmonicsDegree]; - - const { - bytesPerCenter, - bytesPerScale, - bytesPerRotation, - bytesPerColor, - bytesPerSphericalHarmonicsComponent, - scaleOffsetBytes, - rotationOffsetBytes, - colorOffsetBytes, - sphericalHarmonicsOffsetBytes, - } = KSPLAT_COMPRESSION[compressionLevel]; - const bytesPerSplat = - bytesPerCenter + - bytesPerScale + - bytesPerRotation + - bytesPerColor + - shComponents * bytesPerSphericalHarmonicsComponent; - const splatDataStorageSizeBytes = bytesPerSplat * sectionMaxSplatCount; - const storageSizeBytes = - splatDataStorageSizeBytes + bucketsStorageSizeBytes; - - const sh1Index = [0, 3, 6, 1, 4, 7, 2, 5, 8]; - const sh2Index = [ - 9, 14, 19, 10, 15, 20, 11, 16, 21, 12, 17, 22, 13, 18, 23, - ]; - const sh3Index = [ - 24, 31, 38, 25, 32, 39, 26, 33, 40, 27, 34, 41, 28, 35, 42, 29, 36, 43, - 30, 37, 44, - ]; - const sh1 = - sphericalHarmonicsDegree >= 1 ? new Float32Array(3 * 3) : undefined; - const sh2 = - sphericalHarmonicsDegree >= 2 ? new Float32Array(5 * 3) : undefined; - const sh3 = - sphericalHarmonicsDegree >= 3 ? new Float32Array(7 * 3) : undefined; - - const compressionScaleFactor = bucketBlockSize / 2 / compressionScaleRange; - const bucketsBase = sectionBase + bucketsMetaDataSizeBytes; - const dataBase = sectionBase + bucketsStorageSizeBytes; - const data = new DataView( - fileBytes.buffer, - dataBase, - splatDataStorageSizeBytes, - ); - const bucketArray = new Float32Array( - fileBytes.buffer, - bucketsBase, - bucketCount * 3, - ); - const partiallyFilledBucketLengths = new Uint32Array( - fileBytes.buffer, - sectionBase, - partiallyFilledBucketCount, - ); - - function getSh(splatOffset: number, component: number) { - if (compressionLevel === 0) { - return data.getFloat32( - splatOffset + sphericalHarmonicsOffsetBytes + component * 4, - true, - ); - } - if (compressionLevel === 1) { - return fromHalf( - data.getUint16( - splatOffset + sphericalHarmonicsOffsetBytes + component * 2, - true, - ), - ); - } - const t = - data.getUint8(splatOffset + sphericalHarmonicsOffsetBytes + component) / - 255; - return ( - minSphericalHarmonicsCoeff + - t * (maxSphericalHarmonicsCoeff - minSphericalHarmonicsCoeff) - ); - } - - let partialBucketIndex = fullBucketCount; - let partialBucketBase = fullBucketSplats; - - for (let i = 0; i < sectionSplatCount; ++i) { - const splatOffset = i * bytesPerSplat; - - let bucketIndex: number; - if (i < fullBucketSplats) { - bucketIndex = Math.floor(i / bucketSize); - } else { - const bucketLength = - partiallyFilledBucketLengths[partialBucketIndex - fullBucketCount]; - if (i >= partialBucketBase + bucketLength) { - partialBucketIndex += 1; - partialBucketBase += bucketLength; - } - bucketIndex = partialBucketIndex; - } - - const x = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 0, true) - : (data.getUint16(splatOffset + 0, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 0]; - const y = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 4, true) - : (data.getUint16(splatOffset + 2, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 1]; - const z = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 8, true) - : (data.getUint16(splatOffset + 4, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 2]; - - const scaleX = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 0, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 0, true)); - const scaleY = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 4, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 2, true)); - const scaleZ = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 8, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 4, true)); - - const quatW = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 0, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 0, true), - ); - const quatX = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 4, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 2, true), - ); - const quatY = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 8, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 4, true), - ); - const quatZ = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 12, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 6, true), - ); - - const r = data.getUint8(splatOffset + colorOffsetBytes + 0) / 255; - const g = data.getUint8(splatOffset + colorOffsetBytes + 1) / 255; - const b = data.getUint8(splatOffset + colorOffsetBytes + 2) / 255; - const opacity = data.getUint8(splatOffset + colorOffsetBytes + 3) / 255; - - setPackedSplat( - packedArray, - i, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - splatEncoding, - ); - - if (sphericalHarmonicsDegree >= 1) { - if (sh1) { - if (!extra.sh1) { - extra.sh1 = new Uint32Array(numSplats * 2); - } - for (const [i, key] of sh1Index.entries()) { - sh1[i] = getSh(splatOffset, key); - } - encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1, splatEncoding); - } - if (sh2) { - if (!extra.sh2) { - extra.sh2 = new Uint32Array(numSplats * 4); - } - for (const [i, key] of sh2Index.entries()) { - sh2[i] = getSh(splatOffset, key); - } - encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2, splatEncoding); - } - if (sh3) { - if (!extra.sh3) { - extra.sh3 = new Uint32Array(numSplats * 4); - } - for (const [i, key] of sh3Index.entries()) { - sh3[i] = getSh(splatOffset, key); - } - encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3, splatEncoding); - } - } - } - sectionBase += storageSizeBytes; - } - return { packedArray, numSplats, extra }; -} diff --git a/src/pcsogs.ts b/src/pcsogs.ts deleted file mode 100644 index ca883a71..00000000 --- a/src/pcsogs.ts +++ /dev/null @@ -1,387 +0,0 @@ -import { unzip } from "fflate"; -import { - type PcSogsJson, - type PcSogsV2Json, - tryPcSogsZip, -} from "./SplatLoader"; -import type { SplatEncoding } from "./defines"; -import { - computeMaxSplats, - encodeSh1Rgb, - encodeSh2Rgb, - encodeSh3Rgb, - setPackedSplatCenter, - setPackedSplatQuat, - setPackedSplatRgba, - setPackedSplatScales, -} from "./utils"; - -export async function unpackPcSogs( - json: PcSogsJson | PcSogsV2Json, - extraFiles: Record, - splatEncoding: SplatEncoding, -): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra: Record; -}> { - const isVersion2 = "version" in json; - - if (!isVersion2 && json.quats.encoding !== "quaternion_packed") { - throw new Error("Unsupported quaternion encoding"); - } - - const numSplats = isVersion2 ? json.count : json.means.shape[0]; - const maxSplats = computeMaxSplats(numSplats); - const packedArray = new Uint32Array(maxSplats * 4); - const extra: Record = {}; - - const meansPromise = Promise.all([ - decodeImageRgba(extraFiles[json.means.files[0]]), - decodeImageRgba(extraFiles[json.means.files[1]]), - ]).then((means) => { - for (let i = 0; i < numSplats; ++i) { - const i4 = i * 4; - const fx = (means[0][i4 + 0] + (means[1][i4 + 0] << 8)) / 65535; - const fy = (means[0][i4 + 1] + (means[1][i4 + 1] << 8)) / 65535; - const fz = (means[0][i4 + 2] + (means[1][i4 + 2] << 8)) / 65535; - let x = - json.means.mins[0] + (json.means.maxs[0] - json.means.mins[0]) * fx; - let y = - json.means.mins[1] + (json.means.maxs[1] - json.means.mins[1]) * fy; - let z = - json.means.mins[2] + (json.means.maxs[2] - json.means.mins[2]) * fz; - x = Math.sign(x) * (Math.exp(Math.abs(x)) - 1); - y = Math.sign(y) * (Math.exp(Math.abs(y)) - 1); - z = Math.sign(z) * (Math.exp(Math.abs(z)) - 1); - setPackedSplatCenter(packedArray, i, x, y, z); - } - }); - - const scalesPromise = decodeImageRgba(extraFiles[json.scales.files[0]]).then( - (scales) => { - let xLookup: number[]; - let yLookup: number[]; - let zLookup: number[]; - - if (isVersion2) { - xLookup = - yLookup = - zLookup = - json.scales.codebook.map((x) => Math.exp(x)); - } else { - xLookup = new Array(256) - .fill(0) - .map( - (_, i) => - json.scales.mins[0] + - (json.scales.maxs[0] - json.scales.mins[0]) * (i / 255), - ) - .map((x) => Math.exp(x)); - yLookup = new Array(256) - .fill(0) - .map( - (_, i) => - json.scales.mins[1] + - (json.scales.maxs[1] - json.scales.mins[1]) * (i / 255), - ) - .map((x) => Math.exp(x)); - zLookup = new Array(256) - .fill(0) - .map( - (_, i) => - json.scales.mins[2] + - (json.scales.maxs[2] - json.scales.mins[2]) * (i / 255), - ) - .map((x) => Math.exp(x)); - } - - for (let i = 0; i < numSplats; ++i) { - const i4 = i * 4; - setPackedSplatScales( - packedArray, - i, - xLookup[scales[i4 + 0]], - yLookup[scales[i4 + 1]], - zLookup[scales[i4 + 2]], - splatEncoding, - ); - } - }, - ); - - const quatsPromise = decodeImageRgba(extraFiles[json.quats.files[0]]).then( - (quats) => { - const SQRT2 = Math.sqrt(2); - const lookup = new Array(256) - .fill(0) - .map((_, i) => (i / 255 - 0.5) * SQRT2); - - for (let i = 0; i < numSplats; ++i) { - const i4 = i * 4; - const r0 = lookup[quats[i4 + 0]]; - const r1 = lookup[quats[i4 + 1]]; - const r2 = lookup[quats[i4 + 2]]; - const rr = Math.sqrt(Math.max(0, 1.0 - r0 * r0 - r1 * r1 - r2 * r2)); - const rOrder = quats[i4 + 3] - 252; - const quatX = rOrder === 0 ? r0 : rOrder === 1 ? rr : r1; - const quatY = rOrder <= 1 ? r1 : rOrder === 2 ? rr : r2; - const quatZ = rOrder <= 2 ? r2 : rr; - const quatW = rOrder === 0 ? rr : r0; - setPackedSplatQuat(packedArray, i, quatX, quatY, quatZ, quatW); - } - }, - ); - const sh0Promise = decodeImageRgba(extraFiles[json.sh0.files[0]]).then( - (sh0) => { - const SH_C0 = 0.28209479177387814; - let rLookup: number[]; - let gLookup: number[]; - let bLookup: number[]; - let aLookup: number[]; - - if (isVersion2) { - rLookup = - gLookup = - bLookup = - json.sh0.codebook.map((x) => SH_C0 * x + 0.5); - aLookup = new Array(256).fill(0).map((_, i) => i / 255); - } else { - rLookup = new Array(256) - .fill(0) - .map( - (_, i) => - json.sh0.mins[0] + - (json.sh0.maxs[0] - json.sh0.mins[0]) * (i / 255), - ) - .map((x) => SH_C0 * x + 0.5); - gLookup = new Array(256) - .fill(0) - .map( - (_, i) => - json.sh0.mins[1] + - (json.sh0.maxs[1] - json.sh0.mins[1]) * (i / 255), - ) - .map((x) => SH_C0 * x + 0.5); - bLookup = new Array(256) - .fill(0) - .map( - (_, i) => - json.sh0.mins[2] + - (json.sh0.maxs[2] - json.sh0.mins[2]) * (i / 255), - ) - .map((x) => SH_C0 * x + 0.5); - aLookup = new Array(256) - .fill(0) - .map( - (_, i) => - json.sh0.mins[3] + - (json.sh0.maxs[3] - json.sh0.mins[3]) * (i / 255), - ) - .map((x) => 1.0 / (1.0 + Math.exp(-x))); - } - - for (let i = 0; i < numSplats; ++i) { - const i4 = i * 4; - setPackedSplatRgba( - packedArray, - i, - rLookup[sh0[i4 + 0]], - gLookup[sh0[i4 + 1]], - bLookup[sh0[i4 + 2]], - aLookup[sh0[i4 + 3]], - splatEncoding, - ); - } - }, - ); - - const promises = [meansPromise, scalesPromise, quatsPromise, sh0Promise]; - if (json.shN) { - const useSH3 = isVersion2 - ? json.shN.bands >= 3 - : json.shN.shape[1] >= 48 - 3; - const useSH2 = isVersion2 - ? json.shN.bands >= 2 - : json.shN.shape[1] >= 27 - 3; - const useSH1 = isVersion2 - ? json.shN.bands >= 1 - : json.shN.shape[1] >= 12 - 3; - - if (useSH1) extra.sh1 = new Uint32Array(numSplats * 2); - if (useSH2) extra.sh2 = new Uint32Array(numSplats * 4); - if (useSH3) extra.sh3 = new Uint32Array(numSplats * 4); - - const sh1 = new Float32Array(9); - const sh2 = new Float32Array(15); - const sh3 = new Float32Array(21); - - const shN = json.shN; - const shNPromise = Promise.all([ - decodeImage(extraFiles[json.shN.files[0]]), - decodeImage(extraFiles[json.shN.files[1]]), - ]).then(([centroids, labels]) => { - const lookup = - "codebook" in shN - ? shN.codebook - : new Array(256) - .fill(0) - .map((_, i) => shN.mins + (shN.maxs - shN.mins) * (i / 255)); - - for (let i = 0; i < numSplats; ++i) { - const i4 = i * 4; - const label = labels.rgba[i4 + 0] + (labels.rgba[i4 + 1] << 8); - const col = (label & 63) * 15; - const row = label >>> 6; - const offset = row * centroids.width + col; - - for (let d = 0; d < 3; ++d) { - if (useSH1) { - for (let k = 0; k < 3; ++k) { - sh1[k * 3 + d] = lookup[centroids.rgba[(offset + k) * 4 + d]]; - } - } - - if (useSH2) { - for (let k = 0; k < 5; ++k) { - sh2[k * 3 + d] = lookup[centroids.rgba[(offset + 3 + k) * 4 + d]]; - } - } - - if (useSH3) { - for (let k = 0; k < 7; ++k) { - sh3[k * 3 + d] = lookup[centroids.rgba[(offset + 8 + k) * 4 + d]]; - } - } - } - - if (useSH1) - encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1, splatEncoding); - if (useSH2) - encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2, splatEncoding); - if (useSH3) - encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3, splatEncoding); - } - }); - promises.push(shNPromise); - } - - await Promise.all(promises); - - return { packedArray, numSplats, extra }; -} - -// WebGL context for reading raw pixel data of WebP images -let offscreenGlContext: WebGL2RenderingContext | null = null; - -async function decodeImage(fileBytes: ArrayBuffer) { - if (!offscreenGlContext) { - const canvas = new OffscreenCanvas(1, 1); - offscreenGlContext = canvas.getContext("webgl2"); - if (!offscreenGlContext) { - throw new Error("Failed to create WebGL2 context"); - } - } - - const imageBlob = new Blob([fileBytes]); - const bitmap = await createImageBitmap(imageBlob, { - premultiplyAlpha: "none", - }); - - const gl = offscreenGlContext; - const texture = gl.createTexture(); - gl.bindTexture(gl.TEXTURE_2D, texture); - gl.pixelStorei(gl.UNPACK_FLIP_Y_WEBGL, true); - gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, bitmap); - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST); - gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST); - - const framebuffer = gl.createFramebuffer(); - gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); - gl.framebufferTexture2D( - gl.FRAMEBUFFER, - gl.COLOR_ATTACHMENT0, - gl.TEXTURE_2D, - texture, - 0, - ); - - const data = new Uint8Array(bitmap.width * bitmap.height * 4); - gl.readPixels( - 0, - 0, - bitmap.width, - bitmap.height, - gl.RGBA, - gl.UNSIGNED_BYTE, - data, - ); - - gl.deleteTexture(texture); - gl.deleteFramebuffer(framebuffer); - - return { rgba: data, width: bitmap.width, height: bitmap.height }; -} - -async function decodeImageRgba(fileBytes: ArrayBuffer) { - const { rgba } = await decodeImage(fileBytes); - return rgba; -} - -export async function unpackPcSogsZip( - fileBytes: Uint8Array, - splatEncoding: SplatEncoding, -): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra: Record; -}> { - const nameJson = tryPcSogsZip(fileBytes); - if (!nameJson) { - throw new Error("Invalid PC SOGS zip file"); - } - const { name, json } = nameJson; - // Find path prefix, will be -1 if no / or \ - const lastSlash = name.lastIndexOf("/"); - const lastBackslash = name.lastIndexOf("\\"); - const prefix = name.slice(0, Math.max(lastSlash, lastBackslash) + 1); - - const fileMap = new Map(); - const refFiles = [ - ...json.means.files, - ...json.scales.files, - ...json.quats.files, - ...json.sh0.files, - ...(json.shN?.files ?? []), - ]; - for (const file of refFiles) { - fileMap.set(prefix + file, file); - } - - const unzipped = await new Promise>( - (resolve, reject) => { - unzip( - fileBytes, - { - filter: ({ name }) => { - return fileMap.has(name); - }, - }, - (err, files) => { - if (err) { - reject(err); - } else { - resolve(files); - } - }, - ); - }, - ); - - const extraFiles: Record = {}; - for (const [full, name] of fileMap.entries()) { - extraFiles[name] = unzipped[full]; - } - - return await unpackPcSogs(json, extraFiles, splatEncoding); -} diff --git a/src/ply.ts b/src/ply.ts deleted file mode 100644 index b3ce24bd..00000000 --- a/src/ply.ts +++ /dev/null @@ -1,1091 +0,0 @@ -// PLY file format reader - -import { USE_COMPILED_PARSER_FUNCTION } from "./defines"; - -const PLY_PROPERTY_TYPES = [ - "char", - "uchar", - "short", - "ushort", - "int", - "uint", - "float", - "double", -] as const; -export type PlyPropertyType = (typeof PLY_PROPERTY_TYPES)[number]; - -export type PlyElement = { - name: string; - count: number; - properties: Record; -}; - -export type PlyProperty = { - isList: boolean; - type: PlyPropertyType; - countType?: PlyPropertyType; -}; - -// Callback for parseSplats base Gsplat data -export type SplatCallback = ( - index: number, - x: number, - y: number, - z: number, - scaleX: number, - scaleY: number, - scaleZ: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - opacity: number, - r: number, - g: number, - b: number, -) => void; - -// Callback for parseSplats SH coefficients -export type SplatShCallback = ( - index: number, - sh1: Float32Array, - sh2?: Float32Array, - sh3?: Float32Array, -) => void; - -// A PlyReader is used to parse PLY files for Gsplat data. -// It takes a Uint8Array/ArrayBuffer as input fileBytes, parses the text header, -// and provides a method parseData to iterate over the entire binary data -// efficiently, or parseSplats to iterate over Gsplat data. - -export class PlyReader { - fileBytes: Uint8Array; - header = ""; - littleEndian = true; - elements: Record = {}; - comments: string[] = []; - data: DataView | null = null; - static defaultPointScale = 0.001; - - numSplats = 0; - - // Create a PlyReader from a Uint8Array/ArrayBuffer, no parsing done yet - constructor({ fileBytes }: { fileBytes: Uint8Array | ArrayBuffer }) { - this.fileBytes = - fileBytes instanceof ArrayBuffer ? new Uint8Array(fileBytes) : fileBytes; - } - - // Identify and parse the PLY text header (assumed to be <64KB in size). - // this.elements will contain all the elements in the file, typically - // "vertex" contains the Gsplat data. - async parseHeader() { - const bufferStream = new ReadableStream({ - start: ( - controller: ReadableStreamController>, - ) => { - // Assume the header is less than 64KB - controller.enqueue(this.fileBytes.slice(0, 65536)); - controller.close(); - }, - }); - const decoder = bufferStream - .pipeThrough(new TextDecoderStream()) - .getReader(); - - // Find the end of the text section of the PLY file - this.header = ""; - const headerTerminator = "end_header\n"; - while (true) { - const { value, done } = await decoder.read(); - if (done) { - throw new Error("Failed to read header"); - } - - this.header += value as string; - const endHeader = this.header.indexOf(headerTerminator); - if (endHeader >= 0) { - this.header = this.header.slice(0, endHeader + headerTerminator.length); - break; - } - } - // Partition the file into header and binary data - const headerLen = new TextEncoder().encode(this.header).length; - this.data = new DataView(this.fileBytes.buffer, headerLen); - - this.elements = {}; - let curElement: PlyElement | null = null; - this.comments = []; - - this.header - .trim() - .split("\n") - .forEach((line: string, lineIndex: number) => { - const trimmedLine = line.trim(); - if (lineIndex === 0) { - if (trimmedLine !== "ply") { - throw new Error("Invalid PLY header"); - } - return; - } - if (trimmedLine.length === 0) { - return; // Skip empty lines - } - - const fields = trimmedLine.split(" "); - switch (fields[0]) { - case "format": - if (fields[1] === "binary_little_endian") { - this.littleEndian = true; - } else if (fields[1] === "binary_big_endian") { - this.littleEndian = false; - } else { - // ascii formats not supported - throw new Error(`Unsupported PLY format: ${fields[1]}`); - } - if (fields[2] !== "1.0") { - throw new Error(`Unsupported PLY version: ${fields[2]}`); - } - break; - case "end_header": - break; - case "comment": - this.comments.push(trimmedLine.slice("comment ".length)); - break; - case "element": { - const name = fields[1]; - curElement = { - name, - count: Number.parseInt(fields[2]), - properties: {}, - }; - this.elements[name] = curElement; - break; - } - case "property": - if (curElement == null) { - throw new Error("Property must be inside an element"); - } - if (fields[1] === "list") { - curElement.properties[fields[4]] = { - isList: true, - type: fields[3] as PlyPropertyType, - countType: fields[2] as PlyPropertyType, - }; - } else { - curElement.properties[fields[2]] = { - isList: false, - type: fields[1] as PlyPropertyType, - }; - } - break; - default: - // console.warn(`Skipping unsupported PLY keyword: ${fields[0]}`); - } - }); - - if (this.elements.vertex) { - this.numSplats = this.elements.vertex.count; - } - } - - parseData( - elementCallback: ( - element: PlyElement, - ) => - | null - | ((index: number, item: Record) => void), - ) { - // Go through the entire binary data of the PLY file, starting at offset 0 - let offset = 0; - const data = this.data; - if (data == null) { - throw new Error("No data to parse"); - } - - for (const elementName in this.elements) { - const element = this.elements[elementName]; - const { count, properties } = element; - const item = createEmptyItem(properties); - // Construct a parse function - const parseFn = createParseFn(properties, this.littleEndian); - - // Parse all the items in the element - const callback = elementCallback(element) ?? (() => {}); - for (let index = 0; index < count; index++) { - offset = parseFn(data, offset, item); - callback(index, item); - } - } - } - - // Parse all the Gsplat data in the PLY file in go, invoking the given - // callbacks for each Gsplat. - parseSplats(splatCallback: SplatCallback, shCallback?: SplatShCallback) { - if (this.elements.vertex == null) { - throw new Error("No vertex element found"); - } - - let isSuperSplat = false; - const ssChunks: SSChunk[] = []; - - let numSh = 0; - let sh1Props: number[] = []; - let sh2Props: number[] = []; - let sh3Props: number[] = []; - let sh1: Float32Array | undefined = undefined; - let sh2: Float32Array | undefined = undefined; - let sh3: Float32Array | undefined = undefined; - - function prepareSh() { - // Prepare SH coefficient names and arrays for numSh total SH levels - const num_f_rest = NUM_SH_TO_NUM_F_REST[numSh]; - sh1Props = new Array(3) - .fill(null) - .flatMap((_, k) => [0, 1, 2].map((_, d) => k + (d * num_f_rest) / 3)); - sh2Props = new Array(5) - .fill(null) - .flatMap((_, k) => - [0, 1, 2].map((_, d) => 3 + k + (d * num_f_rest) / 3), - ); - sh3Props = new Array(7) - .fill(null) - .flatMap((_, k) => - [0, 1, 2].map((_, d) => 8 + k + (d * num_f_rest) / 3), - ); - sh1 = numSh >= 1 ? new Float32Array(3 * 3) : undefined; - sh2 = numSh >= 2 ? new Float32Array(5 * 3) : undefined; - sh3 = numSh >= 3 ? new Float32Array(7 * 3) : undefined; - } - - function ssShCallback( - index: number, - item: Record, - ) { - // Decode SH for SuperSplat compressed data - if (!sh1) { - throw new Error("Missing sh1"); - } - const sh = item.f_rest as number[]; - - for (let i = 0; i < sh1Props.length; i++) { - sh1[i] = (sh[sh1Props[i]] * 8) / 255 - 4; - } - if (sh2) { - for (let i = 0; i < sh2Props.length; i++) { - sh2[i] = (sh[sh2Props[i]] * 8) / 255 - 4; - } - } - if (sh3) { - for (let i = 0; i < sh3Props.length; i++) { - sh3[i] = (sh[sh3Props[i]] * 8) / 255 - 4; - } - } - shCallback?.(index, sh1, sh2, sh3); - } - - function initSuperSplat(element: PlyElement) { - const { - min_x, - min_y, - min_z, - max_x, - max_y, - max_z, - min_scale_x, - min_scale_y, - min_scale_z, - max_scale_x, - max_scale_y, - max_scale_z, - } = element.properties; - if ( - !min_x || - !min_y || - !min_z || - !max_x || - !max_y || - !max_z || - !min_scale_x || - !min_scale_y || - !min_scale_z || - !max_scale_x || - !max_scale_y || - !max_scale_z - ) { - throw new Error("Missing PLY chunk properties"); - } - - // SuperSplat chunks are used to quantize splat data, so we need to store them - isSuperSplat = true; - return (index: number, item: Record) => { - const { - min_x, - min_y, - min_z, - max_x, - max_y, - max_z, - min_scale_x, - min_scale_y, - min_scale_z, - max_scale_x, - max_scale_y, - max_scale_z, - min_r, - min_g, - min_b, - max_r, - max_g, - max_b, - } = item as Record; - ssChunks.push({ - min_x, - min_y, - min_z, - max_x, - max_y, - max_z, - min_scale_x, - min_scale_y, - min_scale_z, - max_scale_x, - max_scale_y, - max_scale_z, - min_r, - min_g, - min_b, - max_r, - max_g, - max_b, - }); - }; - } - - function decodeSuperSplat(element: PlyElement) { - // Decode SuperSplat compressed data in vertex and sh elements - if (shCallback && element.name === "sh") { - numSh = getNumSh(element.properties); - prepareSh(); - return ssShCallback; - } - if (element.name !== "vertex") { - return null; - } - - const { packed_position, packed_rotation, packed_scale, packed_color } = - element.properties; - if ( - !packed_position || - !packed_rotation || - !packed_scale || - !packed_color - ) { - throw new Error( - "Missing PLY properties: packed_position, packed_rotation, packed_scale, packed_color", - ); - } - - const SQRT2 = Math.sqrt(2); - - return (index: number, item: Record) => { - // SuperSplat data are quantized within chunks with 256 Gsplats each - const chunk = ssChunks[index >>> 8]; - if (chunk == null) { - throw new Error("Missing PLY chunk"); - } - const { - min_x, - min_y, - min_z, - max_x, - max_y, - max_z, - min_scale_x, - min_scale_y, - min_scale_z, - max_scale_x, - max_scale_y, - max_scale_z, - min_r, - min_g, - min_b, - max_r, - max_g, - max_b, - } = chunk; - const { packed_position, packed_rotation, packed_scale, packed_color } = - item as Record; - - const x = - (((packed_position >>> 21) & 2047) / 2047) * (max_x - min_x) + min_x; - const y = - (((packed_position >>> 11) & 1023) / 1023) * (max_y - min_y) + min_y; - const z = ((packed_position & 2047) / 2047) * (max_z - min_z) + min_z; - - const r0 = (((packed_rotation >>> 20) & 1023) / 1023 - 0.5) * SQRT2; - const r1 = (((packed_rotation >>> 10) & 1023) / 1023 - 0.5) * SQRT2; - const r2 = ((packed_rotation & 1023) / 1023 - 0.5) * SQRT2; - const rr = Math.sqrt(Math.max(0, 1.0 - r0 * r0 - r1 * r1 - r2 * r2)); - - const rOrder = packed_rotation >>> 30; - const quatX = rOrder === 0 ? r0 : rOrder === 1 ? rr : r1; - const quatY = rOrder <= 1 ? r1 : rOrder === 2 ? rr : r2; - const quatZ = rOrder <= 2 ? r2 : rr; - const quatW = rOrder === 0 ? rr : r0; - - const scaleX = Math.exp( - (((packed_scale >>> 21) & 2047) / 2047) * - (max_scale_x - min_scale_x) + - min_scale_x, - ); - const scaleY = Math.exp( - (((packed_scale >>> 11) & 1023) / 1023) * - (max_scale_y - min_scale_y) + - min_scale_y, - ); - const scaleZ = Math.exp( - ((packed_scale & 2047) / 2047) * (max_scale_z - min_scale_z) + - min_scale_z, - ); - - const r = - (((packed_color >>> 24) & 255) / 255) * - ((max_r ?? 1) - (min_r ?? 0)) + - (min_r ?? 0); - const g = - (((packed_color >>> 16) & 255) / 255) * - ((max_g ?? 1) - (min_g ?? 0)) + - (min_g ?? 0); - const b = - (((packed_color >>> 8) & 255) / 255) * ((max_b ?? 1) - (min_b ?? 0)) + - (min_b ?? 0); - const opacity = (packed_color & 255) / 255; - - splatCallback( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ); - }; - } - - const elementCallback = (element: PlyElement) => { - if (element.name === "chunk") { - // "chunk" could conceivably be used for other formats, and we would - // ideally check for the comment: Generated by SuperSplat 2.* - // but gsplat also outputs this format without such a comment. - // In order to support both, let's assume a "chunk" element should - // be interpreted as this format. - return initSuperSplat(element); - } - if (isSuperSplat) { - return decodeSuperSplat(element); - } - - if (element.name !== "vertex") { - return null; - } - - const { - x, - y, - z, - scale_0, - scale_1, - scale_2, - rot_0, - rot_1, - rot_2, - rot_3, - opacity, - f_dc_0, - f_dc_1, - f_dc_2, - red, - green, - blue, - alpha, - } = element.properties; - - if (!x || !y || !z) { - throw new Error("Missing PLY properties: x, y, z"); - } - // Pure point cloud PLY files have no scales or rotations - const hasScales = scale_0 && scale_1 && scale_2; - const hasRots = rot_0 && rot_1 && rot_2 && rot_3; - // Quantization scale factor for argb values - const alphaDiv = alpha != null ? FIELD_SCALE[alpha.type] : 1; - const redDiv = red != null ? FIELD_SCALE[red.type] : 1; - const greenDiv = green != null ? FIELD_SCALE[green.type] : 1; - const blueDiv = blue != null ? FIELD_SCALE[blue.type] : 1; - - numSh = getNumSh(element.properties); - prepareSh(); - - return (index: number, item: Record) => { - const scaleX = hasScales - ? Math.exp(item.scale_0 as number) - : PlyReader.defaultPointScale; - const scaleY = hasScales - ? Math.exp(item.scale_1 as number) - : PlyReader.defaultPointScale; - const scaleZ = hasScales - ? Math.exp(item.scale_2 as number) - : PlyReader.defaultPointScale; - - const quatX = hasRots ? (item.rot_1 as number) : 0; - const quatY = hasRots ? (item.rot_2 as number) : 0; - const quatZ = hasRots ? (item.rot_3 as number) : 0; - const quatW = hasRots ? (item.rot_0 as number) : 1; - - const op = - opacity != null - ? 1.0 / (1.0 + Math.exp(-item.opacity as number)) - : alpha != null - ? (item.alpha as number) / alphaDiv - : 1.0; - const r = - f_dc_0 != null - ? (item.f_dc_0 as number) * SH_C0 + 0.5 - : red != null - ? (item.red as number) / redDiv - : 1.0; - const g = - f_dc_1 != null - ? (item.f_dc_1 as number) * SH_C0 + 0.5 - : green != null - ? (item.green as number) / greenDiv - : 1.0; - const b = - f_dc_2 != null - ? (item.f_dc_2 as number) * SH_C0 + 0.5 - : blue != null - ? (item.blue as number) / blueDiv - : 1.0; - - splatCallback( - index, - item.x as number, - item.y as number, - item.z as number, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - op, - r, - g, - b, - ); - - if (shCallback && sh1) { - const sh = item.f_rest as number[]; - if (sh1) { - for (let i = 0; i < sh1Props.length; i++) { - sh1[i] = sh[sh1Props[i]]; - } - } - if (sh2) { - for (let i = 0; i < sh2Props.length; i++) { - sh2[i] = sh[sh2Props[i]]; - } - } - if (sh3) { - for (let i = 0; i < sh3Props.length; i++) { - sh3[i] = sh[sh3Props[i]]; - } - } - shCallback(index, sh1, sh2, sh3); - } - }; - }; - - this.parseData(elementCallback); - } - - // Inject RGBA values into original PLY file, which can be used to modify - // the color/opacity of the Gsplats and write out the modified PLY file. - injectRgba(rgba: Uint8Array) { - // Go through the entire binary data of the PLY file, starting at offset 0 - let offset = 0; - const data = this.data; - if (data == null) { - throw new Error("No parsed data"); - } - if (rgba.length !== this.numSplats * 4) { - throw new Error("Invalid RGBA array length"); - } - - for (const elementName in this.elements) { - const element = this.elements[elementName]; - const { count, properties } = element; - const parsers = []; - - let rgbaOffset = 0; - const isVertex = elementName === "vertex"; - if (isVertex) { - for (const name of ["opacity", "f_dc_0", "f_dc_1", "f_dc_2"]) { - if (!properties[name] || properties[name].type !== "float") { - throw new Error(`Can't injectRgba due to property: ${name}`); - } - } - } - - for (const [propertyName, property] of Object.entries(properties)) { - if (!property.isList) { - if (isVertex) { - if ( - propertyName === "f_dc_0" || - propertyName === "f_dc_1" || - propertyName === "f_dc_2" - ) { - const component = Number.parseInt( - propertyName.slice("f_dc_".length), - ); - parsers.push(() => { - // Inject DC coefficients - const value = - (rgba[rgbaOffset + component] / 255 - 0.5) / SH_C0; - SET_FIELD[property.type]( - data, - offset, - this.littleEndian, - value, - ); - }); - } else if (propertyName === "opacity") { - parsers.push(() => { - // Inject opacity sigmoid, clamped to [-100, 100] - const value = Math.max( - -100, - Math.min( - 100, - -Math.log(1.0 / (rgba[rgbaOffset + 3] / 255) - 1.0), - ), - ); - SET_FIELD[property.type]( - data, - offset, - this.littleEndian, - value, - ); - }); - } - } - parsers.push(() => { - offset += FIELD_BYTES[property.type]; - }); - } else { - parsers.push(() => { - const length = PARSE_FIELD[property.countType as PlyPropertyType]( - data, - offset, - this.littleEndian, - ); - offset += FIELD_BYTES[property.countType as PlyPropertyType]; - offset += length * FIELD_BYTES[property.type]; - }); - } - } - - for (let index = 0; index < count; index++) { - // Go through all the data and field parsers to compute offset - for (const parser of parsers) { - parser(); - } - if (isVertex) { - rgbaOffset += 4; - } - } - } - } -} - -export const SH_C0 = 0.28209479177387814; - -type FieldParser = ( - data: DataView, - offset: number, - littleEndian: boolean, -) => number; -type FieldSetter = ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, -) => void; - -const PARSE_FIELD: Record = { - char: (data: DataView, offset: number, littleEndian: boolean) => { - return data.getInt8(offset); - }, - uchar: (data: DataView, offset: number, littleEndian: boolean) => { - return data.getUint8(offset); - }, - short: (data: DataView, offset: number, littleEndian: boolean) => { - return data.getInt16(offset, littleEndian); - }, - ushort: (data: DataView, offset: number, littleEndian: boolean) => { - return data.getUint16(offset, littleEndian); - }, - int: (data: DataView, offset: number, littleEndian: boolean) => { - return data.getInt32(offset, littleEndian); - }, - uint: (data: DataView, offset: number, littleEndian: boolean) => { - return data.getUint32(offset, littleEndian); - }, - float: (data: DataView, offset: number, littleEndian: boolean) => { - return data.getFloat32(offset, littleEndian); - }, - double: (data: DataView, offset: number, littleEndian: boolean) => { - return data.getFloat64(offset, littleEndian); - }, -}; - -const SET_FIELD: Record = { - char: ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, - ) => { - data.setInt8(offset, value); - }, - uchar: ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, - ) => { - data.setUint8(offset, value); - }, - short: ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, - ) => { - data.setInt16(offset, value, littleEndian); - }, - ushort: ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, - ) => { - data.setUint16(offset, value, littleEndian); - }, - int: ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, - ) => { - data.setInt32(offset, value, littleEndian); - }, - uint: ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, - ) => { - data.setUint32(offset, value, littleEndian); - }, - float: ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, - ) => { - data.setFloat32(offset, value, littleEndian); - }, - double: ( - data: DataView, - offset: number, - littleEndian: boolean, - value: number, - ) => { - data.setFloat64(offset, value, littleEndian); - }, -}; - -const FIELD_BYTES: Record = { - char: 1, - uchar: 1, - short: 2, - ushort: 2, - int: 4, - uint: 4, - float: 4, - double: 8, -}; - -const FIELD_SCALE: Record = { - char: 127, - uchar: 255, - short: 32767, - ushort: 65535, - int: 2147483647, - uint: 4294967295, - float: 1, - double: 1, -}; - -const NUM_F_REST_TO_NUM_SH: Record = { - 0: 0, - 9: 1, - 24: 2, - 45: 3, -}; -const NUM_SH_TO_NUM_F_REST: Record = { - 0: 0, - 1: 9, - 2: 24, - 3: 45, -}; - -const F_REST_REGEX = /^f_rest_([0-9]{1,2})$/; - -function createEmptyItem( - properties: Record, -): Record { - const item: Record = {}; - for (const [propertyName, property] of Object.entries(properties)) { - // Treat f_rest properties as a single array for performance - if (F_REST_REGEX.test(propertyName)) { - item.f_rest = new Array(getNumSh(properties)); - } else { - item[propertyName] = property.isList ? [] : 0; - } - } - return item; -} - -function createParseFn( - properties: Record, - littleEndian: boolean, -) { - if (USE_COMPILED_PARSER_FUNCTION && safeToCompile(properties)) { - return createCompiledParserFn(properties, littleEndian); - } - return createDynamicParserFn(properties, littleEndian); -} - -// Detect if unsafe eval is allowed in the current execution context -const UNSAFE_EVAL_ALLOWED = (() => { - try { - new Function("return 42;"); - } catch (e) { - return false; - } - return true; -})(); -const PROPERTY_NAME_REGEX = /^[a-zA-Z0-9_]+$/; - -function safeToCompile(properties: Record) { - if (!UNSAFE_EVAL_ALLOWED) { - return false; - } - - for (const [propertyName, property] of Object.entries(properties)) { - if (!PROPERTY_NAME_REGEX.test(propertyName)) { - return false; - } - - if ( - property.isList && - !PLY_PROPERTY_TYPES.includes(property.countType as PlyPropertyType) - ) { - return false; - } - - if (!PLY_PROPERTY_TYPES.includes(property.type)) { - return false; - } - } - return true; -} - -function createCompiledParserFn( - properties: Record, - littleEndian: boolean, -) { - // Construct the parser function source. - const parserSrc: string[] = ["let list;"]; - for (const [propertyName, property] of Object.entries(properties)) { - const fRestMatch = propertyName.match(F_REST_REGEX); - if (fRestMatch) { - const fRestIndex = +fRestMatch[1]; - parserSrc.push(/*js*/ ` - item.f_rest[${fRestIndex}] = PARSE_FIELD['${property.type}'](data, offset, ${littleEndian}); - offset += ${FIELD_BYTES[property.type]}; - `); - } else if (!property.isList) { - parserSrc.push(/*js*/ ` - item['${propertyName}'] = PARSE_FIELD['${property.type}'](data, offset, ${littleEndian}); - offset += ${FIELD_BYTES[property.type]}; - `); - } else { - // Property is a list, so parse the count first - parserSrc.push(/*js*/ ` - list = item['${propertyName}']; - list.length = PARSE_FIELD['${property.countType}'](data, offset, ${littleEndian}); - offset += ${FIELD_BYTES[property.countType as PlyPropertyType]}; - for (let i = 0; i < list.length; i++) { - list[i] = PARSE_FIELD['${property.type}'](data, offset, ${littleEndian}); - offset += ${FIELD_BYTES[property.type]}; - } - `); - } - } - parserSrc.push("return offset;"); - - const fn = new Function( - "data", - "offset", - "item", - "PARSE_FIELD", - parserSrc.join("\n"), - ); - return ( - data: DataView, - offset: number, - item: Record, - ) => fn(data, offset, item, PARSE_FIELD); -} - -function createDynamicParserFn( - properties: Record, - littleEndian: boolean, -) { - // Construct an array of parser function to parse each property in an item - const parsers: Array< - ( - data: DataView, - offset: number, - item: Record, - ) => number - > = []; - for (const [propertyName, property] of Object.entries(properties)) { - const fRestMatch = propertyName.match(F_REST_REGEX); - if (fRestMatch) { - const fRestIndex = +fRestMatch[1]; - parsers.push( - ( - data: DataView, - offset: number, - item: Record, - ) => { - (item.f_rest as number[])[fRestIndex] = PARSE_FIELD[property.type]( - data, - offset, - littleEndian, - ); - return offset + FIELD_BYTES[property.type]; - }, - ); - } else if (!property.isList) { - parsers.push( - ( - data: DataView, - offset: number, - item: Record, - ) => { - item[propertyName] = PARSE_FIELD[property.type]( - data, - offset, - littleEndian, - ); - return offset + FIELD_BYTES[property.type]; - }, - ); - } else { - // Property is a list, so parse the count first - parsers.push( - ( - data: DataView, - offset: number, - item: Record, - ) => { - const list = item[propertyName] as number[]; - list.length = PARSE_FIELD[property.countType as PlyPropertyType]( - data, - offset, - littleEndian, - ); - let currentOffset = - offset + FIELD_BYTES[property.countType as PlyPropertyType]; - for (let i = 0; i < list.length; i++) { - list[i] = PARSE_FIELD[property.type]( - data, - currentOffset, - littleEndian, - ); - currentOffset += FIELD_BYTES[property.type]; - } - return currentOffset; - }, - ); - } - } - - return ( - data: DataView, - offset: number, - item: Record, - ) => { - let currentOffset = offset; - for (let parserIndex = 0; parserIndex < parsers.length; parserIndex++) { - currentOffset = parsers[parserIndex](data, currentOffset, item); - } - return currentOffset; - }; -} - -function getNumSh(properties: Record) { - let num_f_rest = 0; - while (properties[`f_rest_${num_f_rest}`]) { - num_f_rest += 1; - } - const numSh = NUM_F_REST_TO_NUM_SH[num_f_rest]; - if (numSh == null) { - throw new Error(`Unsupported number of SH coefficients: ${num_f_rest}`); - } - return numSh; -} - -type SSChunk = { - min_x: number; - min_y: number; - min_z: number; - max_x: number; - max_y: number; - max_z: number; - min_scale_x: number; - min_scale_y: number; - min_scale_z: number; - max_scale_x: number; - max_scale_y: number; - max_scale_z: number; - min_r?: number; - min_g?: number; - min_b?: number; - max_r?: number; - max_g?: number; - max_b?: number; -}; diff --git a/src/spz.ts b/src/spz.ts index 75d3d832..ac898cbe 100644 --- a/src/spz.ts +++ b/src/spz.ts @@ -1,519 +1,14 @@ -import * as THREE from "three"; +import type { PackedSplats } from "./PackedSplats"; import { - SplatData, type TranscodeSpzInput, getSplatFileType, getSplatFileTypeFromPath, } from "./SplatLoader"; -import { GunzipReader, fromHalf, normalize } from "./utils"; -import { decodeAntiSplat } from "./antisplat"; -import { SplatFileType } from "./defines"; -import { decodeKsplat } from "./ksplat"; -import { PlyReader } from "./ply"; - -// SPZ file format reader - -export class SpzReader { - fileBytes: Uint8Array; - reader: GunzipReader; - - version = -1; - numSplats = 0; - shDegree = 0; - fractionalBits = 0; - flags = 0; - flagAntiAlias = false; - flagLod = false; - reserved = 0; - headerParsed = false; - parsed = false; - - constructor({ fileBytes }: { fileBytes: Uint8Array | ArrayBuffer }) { - this.fileBytes = - fileBytes instanceof ArrayBuffer ? new Uint8Array(fileBytes) : fileBytes; - this.reader = new GunzipReader({ - fileBytes: this.fileBytes as Uint8Array, - }); - } - - async parseHeader() { - if (this.headerParsed) { - throw new Error("SPZ file header already parsed"); - } - - const header = new DataView((await this.reader.read(16)).buffer); - if (header.getUint32(0, true) !== 0x5053474e) { - throw new Error("Invalid SPZ file"); - } - this.version = header.getUint32(4, true); - if (this.version < 1 || this.version > 3) { - throw new Error(`Unsupported SPZ version: ${this.version}`); - } - - this.numSplats = header.getUint32(8, true); - this.shDegree = header.getUint8(12); - this.fractionalBits = header.getUint8(13); - this.flags = header.getUint8(14); - this.flagAntiAlias = (this.flags & 0x01) !== 0; - this.flagLod = (this.flags & 0x80) !== 0; - this.reserved = header.getUint8(15); - this.headerParsed = true; - this.parsed = false; - } - - async parseSplats( - centerCallback?: (index: number, x: number, y: number, z: number) => void, - alphaCallback?: (index: number, alpha: number) => void, - rgbCallback?: (index: number, r: number, g: number, b: number) => void, - scalesCallback?: ( - index: number, - scaleX: number, - scaleY: number, - scaleZ: number, - ) => void, - quatCallback?: ( - index: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - ) => void, - shCallback?: ( - index: number, - sh1: Float32Array, - sh2?: Float32Array, - sh3?: Float32Array, - ) => void, - { - childCounts, - childStarts, - }: { - childCounts?: (index: number, count: number) => void; - childStarts?: (index: number, start: number) => void; - } = {}, - ) { - if (!this.headerParsed) { - throw new Error("SPZ file header must be parsed first"); - } - if (this.parsed) { - throw new Error("SPZ file already parsed"); - } - this.parsed = true; - - if (this.version === 1) { - // float16 centers - const centerBytes = await this.reader.read(this.numSplats * 3 * 2); - const centerUint16 = new Uint16Array(centerBytes.buffer); - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 3; - const x = fromHalf(centerUint16[i3]); - const y = fromHalf(centerUint16[i3 + 1]); - const z = fromHalf(centerUint16[i3 + 2]); - centerCallback?.(i, x, y, z); - } - } else if (this.version === 2 || this.version === 3) { - // 24-bit fixed-point centers - const fixed = 1 << this.fractionalBits; - const centerBytes = await this.reader.read(this.numSplats * 3 * 3); - for (let i = 0; i < this.numSplats; i++) { - const i9 = i * 9; - const x = - (((centerBytes[i9 + 2] << 24) | - (centerBytes[i9 + 1] << 16) | - (centerBytes[i9] << 8)) >> - 8) / - fixed; - const y = - (((centerBytes[i9 + 5] << 24) | - (centerBytes[i9 + 4] << 16) | - (centerBytes[i9 + 3] << 8)) >> - 8) / - fixed; - const z = - (((centerBytes[i9 + 8] << 24) | - (centerBytes[i9 + 7] << 16) | - (centerBytes[i9 + 6] << 8)) >> - 8) / - fixed; - centerCallback?.(i, x, y, z); - } - } else { - throw new Error("Unreachable"); - } - - { - const bytes = await this.reader.read(this.numSplats); - for (let i = 0; i < this.numSplats; i++) { - alphaCallback?.(i, bytes[i] / 255); - } - } - { - const rgbBytes = await this.reader.read(this.numSplats * 3); - const scale = SH_C0 / 0.15; - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 3; - const r = (rgbBytes[i3] / 255 - 0.5) * scale + 0.5; - const g = (rgbBytes[i3 + 1] / 255 - 0.5) * scale + 0.5; - const b = (rgbBytes[i3 + 2] / 255 - 0.5) * scale + 0.5; - rgbCallback?.(i, r, g, b); - } - } - { - const scalesBytes = await this.reader.read(this.numSplats * 3); - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 3; - const scaleX = Math.exp(scalesBytes[i3] / 16 - 10); - const scaleY = Math.exp(scalesBytes[i3 + 1] / 16 - 10); - const scaleZ = Math.exp(scalesBytes[i3 + 2] / 16 - 10); - scalesCallback?.(i, scaleX, scaleY, scaleZ); - } - } - if (this.version === 3) { - // Version 3 uses a trick called "smallest three" to compress the rotation quaternions - // achieving better precision. "Optimizing orientation" section at https://gafferongames.com/post/snapshot_compression/ A quaternion length must be 1: x^2+y^2+z^2+w^2 = 1 - // We can drop one component and reconstruct it with the identity above. - // Largest component is dropped for best numerical precision. - // Quaternion stored in 32 bits - // 10 bits singed integer for each of the 3 components + 2 bits indicating the index of dropped component. - // vs 8 bits for each component uncompressed (spz version < 3) - // Max Value after extracting largest component v is another component v - // (v,v,0,0) - // v^2 + v^2 = 1 - // v = 1 / sqrt(2); - const maxValue = 1 / Math.sqrt(2); // 0.7071 - const quatBytes = await this.reader.read(this.numSplats * 4); - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 4; - const quaternion = [0, 0, 0, 0]; - const values = [ - quatBytes[i3], - quatBytes[i3 + 1], - quatBytes[i3 + 2], - quatBytes[i3 + 3], - ]; - // all values are packed in 32 bits (10 per each of 3 components + 2 bits of index of larged value) - const combinedValues = - values[0] + (values[1] << 8) + (values[2] << 16) + (values[3] << 24); - // each component value is 9 bits + sign (1 bit) - const valueMask = (1 << 9) - 1; - // extract index of the largest element. 2 top bits. - const largestIndex = combinedValues >>> 30; - let remainingValues = combinedValues; - let sumSquares = 0; - - for (let i = 3; i >= 0; --i) { - if (i !== largestIndex) { - // extract current value and sign. - const value = remainingValues & valueMask; - const sign = (remainingValues >>> 9) & 0x1; - // each value is represented as 10 bits. Shift to next one. - remainingValues = remainingValues >>> 10; - // convert to range [0,1] and then to [0, 0.7071] - quaternion[i] = maxValue * (value / valueMask); - // apply sign. - quaternion[i] = sign === 0 ? quaternion[i] : -quaternion[i]; - // accumulate the sum of squares - sumSquares += quaternion[i] * quaternion[i]; - } - } - - // quartenion length must be 1 (x^2+y^2+z^2+w^2 = 1) - // so can reconstruct largest component from the other 3. - // w = sqrt(1 - x^2 - y^2 - z^2); - const square = 1 - sumSquares; - quaternion[largestIndex] = Math.sqrt(Math.max(square, 0)); - - quatCallback?.( - i, - quaternion[0], - quaternion[1], - quaternion[2], - quaternion[3], - ); - } - } else { - const quatBytes = await this.reader.read(this.numSplats * 3); - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 3; - const quatX = quatBytes[i3] / 127.5 - 1; - const quatY = quatBytes[i3 + 1] / 127.5 - 1; - const quatZ = quatBytes[i3 + 2] / 127.5 - 1; - const quatW = Math.sqrt( - Math.max(0, 1 - quatX * quatX - quatY * quatY - quatZ * quatZ), - ); - quatCallback?.(i, quatX, quatY, quatZ, quatW); - } - } - - if (shCallback && this.shDegree >= 1) { - const sh1 = new Float32Array(3 * 3); - const sh2 = this.shDegree >= 2 ? new Float32Array(5 * 3) : undefined; - const sh3 = this.shDegree >= 3 ? new Float32Array(7 * 3) : undefined; - const shBytes = await this.reader.read( - this.numSplats * SH_DEGREE_TO_VECS[this.shDegree] * 3, - ); - - let offset = 0; - for (let i = 0; i < this.numSplats; i++) { - for (let j = 0; j < 9; ++j) { - sh1[j] = (shBytes[offset + j] - 128) / 128; - } - offset += 9; - if (sh2) { - for (let j = 0; j < 15; ++j) { - sh2[j] = (shBytes[offset + j] - 128) / 128; - } - offset += 15; - } - if (sh3) { - for (let j = 0; j < 21; ++j) { - sh3[j] = (shBytes[offset + j] - 128) / 128; - } - offset += 21; - } - shCallback?.(i, sh1, sh2, sh3); - } - } - if (this.flagLod) { - let bytes = await this.reader.read(this.numSplats * 2); - for (let i = 0; i < this.numSplats; i++) { - const i2 = i * 2; - const count = bytes[i2] + (bytes[i2 + 1] << 8); - childCounts?.(i, count); - } - - bytes = await this.reader.read(this.numSplats * 4); - for (let i = 0; i < this.numSplats; i++) { - const i4 = i * 4; - const start = - bytes[i4] + - (bytes[i4 + 1] << 8) + - (bytes[i4 + 2] << 16) + - (bytes[i4 + 3] << 24); - childStarts?.(i, start); - } - } - } -} - -const SH_DEGREE_TO_VECS: Record = { 1: 3, 2: 8, 3: 15 }; -const SH_C0 = 0.28209479177387814; - -export const SPZ_MAGIC = 0x5053474e; // NGSP = Niantic gaussian splat -export const SPZ_VERSION = 3; -export const FLAG_ANTIALIASED = 0x1; - -export class SpzWriter { - buffer: ArrayBuffer; - view: DataView; - numSplats: number; - shDegree: number; - fractionalBits: number; - fraction: number; - flagAntiAlias: boolean; - clippedCount = 0; - - constructor({ - numSplats, - shDegree, - fractionalBits = 12, - flagAntiAlias = true, - }: { - numSplats: number; - shDegree: number; - fractionalBits?: number; - flagAntiAlias?: boolean; - }) { - const splatSize = - 9 + // Position - 1 + // Opacity - 3 + // Scale - 3 + // DC-rgb - 4 + // Rotation - (shDegree >= 1 ? 9 : 0) + - (shDegree >= 2 ? 15 : 0) + - (shDegree >= 3 ? 21 : 0); - const bufferSize = 16 + numSplats * splatSize; - this.buffer = new ArrayBuffer(bufferSize); - this.view = new DataView(this.buffer); - - this.view.setUint32(0, SPZ_MAGIC, true); // NGSP - this.view.setUint32(4, SPZ_VERSION, true); - this.view.setUint32(8, numSplats, true); - this.view.setUint8(12, shDegree); - this.view.setUint8(13, fractionalBits); - this.view.setUint8(14, flagAntiAlias ? FLAG_ANTIALIASED : 0); - this.view.setUint8(15, 0); // Reserved - - this.numSplats = numSplats; - this.shDegree = shDegree; - this.fractionalBits = fractionalBits; - this.fraction = 1 << fractionalBits; - this.flagAntiAlias = flagAntiAlias; - } - - setCenter(index: number, x: number, y: number, z: number) { - // Divide by this.fraction and round to nearest integer, - // then write as 3-bytes per x then y then z. - const xRounded = Math.round(x * this.fraction); - const xInt = Math.max(-0x7fffff, Math.min(0x7fffff, xRounded)); - const yRounded = Math.round(y * this.fraction); - const yInt = Math.max(-0x7fffff, Math.min(0x7fffff, yRounded)); - const zRounded = Math.round(z * this.fraction); - const zInt = Math.max(-0x7fffff, Math.min(0x7fffff, zRounded)); - const clipped = xRounded !== xInt || yRounded !== yInt || zRounded !== zInt; - if (clipped) { - this.clippedCount += 1; - // if (this.clippedCount < 10) { - // // Write x y z also in hex - // console.log(`Clipped ${index}: ${x}, ${y}, ${z} (0x${x.toString(16)}, 0x${y.toString(16)}, 0x${z.toString(16)}) -> ${xRounded}, ${yRounded}, ${zRounded} (0x${xRounded.toString(16)}, 0x${yRounded.toString(16)}, 0x${zRounded.toString(16)}) -> ${xInt}, ${yInt}, ${zInt} (0x${xInt.toString(16)}, 0x${yInt.toString(16)}, 0x${zInt.toString(16)})`); - // } - } - const i9 = index * 9; - const base = 16 + i9; - this.view.setUint8(base, xInt & 0xff); - this.view.setUint8(base + 1, (xInt >> 8) & 0xff); - this.view.setUint8(base + 2, (xInt >> 16) & 0xff); - this.view.setUint8(base + 3, yInt & 0xff); - this.view.setUint8(base + 4, (yInt >> 8) & 0xff); - this.view.setUint8(base + 5, (yInt >> 16) & 0xff); - this.view.setUint8(base + 6, zInt & 0xff); - this.view.setUint8(base + 7, (zInt >> 8) & 0xff); - this.view.setUint8(base + 8, (zInt >> 16) & 0xff); - } - - setAlpha(index: number, alpha: number) { - const base = 16 + this.numSplats * 9 + index; - this.view.setUint8( - base, - Math.max(0, Math.min(255, Math.round(alpha * 255))), - ); - } - - static scaleRgb(r: number) { - const v = ((r - 0.5) / (SH_C0 / 0.15) + 0.5) * 255; - return Math.max(0, Math.min(255, Math.round(v))); - } - - setRgb(index: number, r: number, g: number, b: number) { - const base = 16 + this.numSplats * 10 + index * 3; - this.view.setUint8(base, SpzWriter.scaleRgb(r)); - this.view.setUint8(base + 1, SpzWriter.scaleRgb(g)); - this.view.setUint8(base + 2, SpzWriter.scaleRgb(b)); - } - - setScale(index: number, scaleX: number, scaleY: number, scaleZ: number) { - const base = 16 + this.numSplats * 13 + index * 3; - this.view.setUint8( - base, - Math.max(0, Math.min(255, Math.round((Math.log(scaleX) + 10) * 16))), - ); - this.view.setUint8( - base + 1, - Math.max(0, Math.min(255, Math.round((Math.log(scaleY) + 10) * 16))), - ); - this.view.setUint8( - base + 2, - Math.max(0, Math.min(255, Math.round((Math.log(scaleZ) + 10) * 16))), - ); - } - - setQuat( - index: number, - ...q: [number, number, number, number] // x, y, z, w - ) { - const base = 16 + this.numSplats * 16 + index * 4; - - const quat = normalize(q); - - // Find largest component - let iLargest = 0; - for (let i = 1; i < 4; ++i) { - if (Math.abs(quat[i]) > Math.abs(quat[iLargest])) { - iLargest = i; - } - } - - // Since -quat represents the same rotation as quat, transform the quaternion so the largest element - // is positive. This avoids having to send its sign bit. - const negate = quat[iLargest] < 0 ? 1 : 0; - - // Do compression using sign bit and 9-bit precision per element. - let comp = iLargest; - for (let i = 0; i < 4; ++i) { - if (i !== iLargest) { - const negbit = (quat[i] < 0 ? 1 : 0) ^ negate; - const mag = Math.floor( - ((1 << 9) - 1) * (Math.abs(quat[i]) / Math.SQRT1_2) + 0.5, - ); - comp = (comp << 10) | (negbit << 9) | mag; - } - } - - this.view.setUint8(base, comp & 0xff); - this.view.setUint8(base + 1, (comp >> 8) & 0xff); - this.view.setUint8(base + 2, (comp >> 16) & 0xff); - this.view.setUint8(base + 3, (comp >>> 24) & 0xff); - } - - static quantizeSh(sh: number, bits: number) { - const value = Math.round(sh * 128) + 128; - const bucketSize = 1 << (8 - bits); - const quantized = - Math.floor((value + bucketSize / 2) / bucketSize) * bucketSize; - return Math.max(0, Math.min(255, quantized)); - } - - setSh( - index: number, - sh1: Float32Array, - sh2?: Float32Array, - sh3?: Float32Array, - ) { - const shVecs = SH_DEGREE_TO_VECS[this.shDegree] || 0; - const base1 = 16 + this.numSplats * 20 + index * shVecs * 3; - for (let j = 0; j < 9; ++j) { - this.view.setUint8(base1 + j, SpzWriter.quantizeSh(sh1[j], 5)); - } - if (sh2) { - const base2 = base1 + 9; - for (let j = 0; j < 15; ++j) { - this.view.setUint8(base2 + j, SpzWriter.quantizeSh(sh2[j], 4)); - } - if (sh3) { - const base3 = base2 + 15; - for (let j = 0; j < 21; ++j) { - this.view.setUint8(base3 + j, SpzWriter.quantizeSh(sh3[j], 4)); - } - } - } - } - - async finalize(): Promise { - const input = new Uint8Array(this.buffer); - const stream = new ReadableStream({ - async start(controller) { - controller.enqueue(input); - controller.close(); - }, - }); - const compressed = stream.pipeThrough(new CompressionStream("gzip")); - const response = new Response(compressed); - const buffer = await response.arrayBuffer(); - console.log( - "Compressed", - input.length, - "bytes to", - buffer.byteLength, - "bytes", - ); - return new Uint8Array(buffer); - } -} +import { decode_to_gsplatarray, packedsplats_to_gsplatarray } from "spark-rs"; export async function transcodeSpz(input: TranscodeSpzInput) { - const splats = new SplatData(); + const splatArrays = []; const { inputs, clipXyz, @@ -523,45 +18,9 @@ export async function transcodeSpz(input: TranscodeSpzInput) { } = input; for (const input of inputs) { const scale = input.transform?.scale ?? 1; - const quaternion = new THREE.Quaternion().fromArray( - input.transform?.quaternion ?? [0, 0, 0, 1], - ); - const translate = new THREE.Vector3().fromArray( - input.transform?.translate ?? [0, 0, 0], - ); - const clip = clipXyz - ? new THREE.Box3( - new THREE.Vector3().fromArray(clipXyz.min), - new THREE.Vector3().fromArray(clipXyz.max), - ) - : undefined; - - function transformPos(pos: THREE.Vector3) { - pos.multiplyScalar(scale); - pos.applyQuaternion(quaternion); - pos.add(translate); - return pos; - } - - function transformScales(scales: THREE.Vector3) { - scales.multiplyScalar(scale); - return scales; - } - - function transformQuaternion(quat: THREE.Quaternion) { - quat.premultiply(quaternion); - return quat; - } - - function withinClip(p: THREE.Vector3) { - return !clip || clip.containsPoint(p); - } - - function withinOpacity(opacity: number) { - return opacityThreshold !== undefined - ? opacity >= opacityThreshold - : true; - } + const quaternion = input.transform?.quaternion ?? [0, 0, 0, 1]; + const translate = input.transform?.translate ?? [0, 0, 0]; + const clip = clipXyz ? [...clipXyz.min, ...clipXyz.max] : undefined; let fileType = input.fileType; if (!fileType) { @@ -570,294 +29,52 @@ export async function transcodeSpz(input: TranscodeSpzInput) { fileType = getSplatFileTypeFromPath(input.pathOrUrl); } } - switch (fileType) { - case SplatFileType.PLY: { - const ply = new PlyReader({ fileBytes: input.fileBytes }); - await ply.parseHeader(); - let lastIndex: number | null = null; - ply.parseSplats( - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - const center = transformPos(new THREE.Vector3(x, y, z)); - if (withinClip(center) && withinOpacity(opacity)) { - lastIndex = splats.pushSplat(); - splats.setCenter(lastIndex, center.x, center.y, center.z); - const scales = transformScales( - new THREE.Vector3(scaleX, scaleY, scaleZ), - ); - splats.setScale(lastIndex, scales.x, scales.y, scales.z); - const quaternion = transformQuaternion( - new THREE.Quaternion(quatX, quatY, quatZ, quatW), - ); - splats.setQuaternion( - lastIndex, - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - ); - splats.setOpacity(lastIndex, opacity); - splats.setColor(lastIndex, r, g, b); - } else { - lastIndex = null; - } - }, - (index, sh1, sh2, sh3) => { - if (sh1 && lastIndex !== null) { - splats.setSh1(lastIndex, sh1); - } - if (sh2 && lastIndex !== null) { - splats.setSh2(lastIndex, sh2); - } - if (sh3 && lastIndex !== null) { - splats.setSh3(lastIndex, sh3); - } - }, - ); - break; - } - case SplatFileType.SPZ: { - const spz = new SpzReader({ fileBytes: input.fileBytes }); - await spz.parseHeader(); - const mapping = new Int32Array(spz.numSplats); - mapping.fill(-1); - const centers = new Float32Array(spz.numSplats * 3); - const center = new THREE.Vector3(); - spz.parseSplats( - (index, x, y, z) => { - const center = transformPos(new THREE.Vector3(x, y, z)); - centers[index * 3] = center.x; - centers[index * 3 + 1] = center.y; - centers[index * 3 + 2] = center.z; - }, - (index, alpha) => { - center.fromArray(centers, index * 3); - if (withinClip(center) && withinOpacity(alpha)) { - mapping[index] = splats.pushSplat(); - splats.setCenter(mapping[index], center.x, center.y, center.z); - splats.setOpacity(mapping[index], alpha); - } - }, - (index, r, g, b) => { - if (mapping[index] >= 0) { - splats.setColor(mapping[index], r, g, b); - } - }, - (index, scaleX, scaleY, scaleZ) => { - if (mapping[index] >= 0) { - const scales = transformScales( - new THREE.Vector3(scaleX, scaleY, scaleZ), - ); - splats.setScale(mapping[index], scales.x, scales.y, scales.z); - } - }, - (index, quatX, quatY, quatZ, quatW) => { - if (mapping[index] >= 0) { - const quaternion = transformQuaternion( - new THREE.Quaternion(quatX, quatY, quatZ, quatW), - ); - splats.setQuaternion( - mapping[index], - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - ); - } - }, - (index, sh1, sh2, sh3) => { - if (mapping[index] >= 0) { - splats.setSh1(mapping[index], sh1); - if (sh2) { - splats.setSh2(mapping[index], sh2); - } - if (sh3) { - splats.setSh3(mapping[index], sh3); - } - } - }, - ); - break; - } - case SplatFileType.SPLAT: - decodeAntiSplat( - input.fileBytes, - (numSplats) => {}, - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - const center = transformPos(new THREE.Vector3(x, y, z)); - if (withinClip(center) && withinOpacity(opacity)) { - const index = splats.pushSplat(); - splats.setCenter(index, center.x, center.y, center.z); - const scales = transformScales( - new THREE.Vector3(scaleX, scaleY, scaleZ), - ); - splats.setScale(index, scales.x, scales.y, scales.z); - const quaternion = transformQuaternion( - new THREE.Quaternion(quatX, quatY, quatZ, quatW), - ); - splats.setQuaternion( - index, - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - ); - splats.setOpacity(index, opacity); - splats.setColor(index, r, g, b); - } - }, - ); - break; - case SplatFileType.KSPLAT: { - let lastIndex: number | null = null; - decodeKsplat( - input.fileBytes, - (numSplats) => {}, - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - const center = transformPos(new THREE.Vector3(x, y, z)); - if (withinClip(center) && withinOpacity(opacity)) { - lastIndex = splats.pushSplat(); - splats.setCenter(lastIndex, center.x, center.y, center.z); - const scales = transformScales( - new THREE.Vector3(scaleX, scaleY, scaleZ), - ); - splats.setScale(lastIndex, scales.x, scales.y, scales.z); - const quaternion = transformQuaternion( - new THREE.Quaternion(quatX, quatY, quatZ, quatW), - ); - splats.setQuaternion( - lastIndex, - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - ); - splats.setOpacity(lastIndex, opacity); - splats.setColor(lastIndex, r, g, b); - } else { - lastIndex = null; - } - }, - (index, sh1, sh2, sh3) => { - if (lastIndex !== null) { - splats.setSh1(lastIndex, sh1); - if (sh2) { - splats.setSh2(lastIndex, sh2); - } - if (sh3) { - splats.setSh3(lastIndex, sh3); - } - } - }, - ); - break; - } - default: - throw new Error(`transcodeSpz not implemented for ${fileType}`); + const decoder = decode_to_gsplatarray(fileType, input.pathOrUrl); + const fileBytes = input.fileBytes; + const CHUNK_SIZE = 1048576; // 1 MB + for (let i = 0; i < fileBytes.length; i += CHUNK_SIZE) { + decoder.push( + fileBytes.subarray(i, Math.min(i + CHUNK_SIZE, fileBytes.length)), + ); } - } + const decoded = decoder.finish(); - const shDegree = Math.min( - maxSh ?? 3, - splats.sh3 ? 3 : splats.sh2 ? 2 : splats.sh1 ? 1 : 0, - ); - const spz = new SpzWriter({ - numSplats: splats.numSplats, - shDegree, - fractionalBits, - flagAntiAlias: true, - }); + decoded.transform({ + translation: translate, + rotation: quaternion, + scale, + clip, + opacityThreshold: opacityThreshold ?? 0, + }); - for (let i = 0; i < splats.numSplats; ++i) { - const i3 = i * 3; - const i4 = i * 4; - spz.setCenter( - i, - splats.centers[i3], - splats.centers[i3 + 1], - splats.centers[i3 + 2], - ); - spz.setScale( - i, - splats.scales[i3], - splats.scales[i3 + 1], - splats.scales[i3 + 2], - ); - spz.setQuat( - i, - splats.quaternions[i4], - splats.quaternions[i4 + 1], - splats.quaternions[i4 + 2], - splats.quaternions[i4 + 3], - ); - spz.setAlpha(i, splats.opacities[i]); - spz.setRgb( - i, - splats.colors[i3], - splats.colors[i3 + 1], - splats.colors[i3 + 2], - ); - if (splats.sh1 && shDegree >= 1) { - spz.setSh( - i, - splats.sh1.slice(i * 9, (i + 1) * 9), - shDegree >= 2 && splats.sh2 - ? splats.sh2.slice(i * 15, (i + 1) * 15) - : undefined, - shDegree >= 3 && splats.sh3 - ? splats.sh3.slice(i * 21, (i + 1) * 21) - : undefined, - ); - } + splatArrays.push(decoded); } - const spzBytes = await spz.finalize(); - return { fileBytes: spzBytes, clippedCount: spz.clippedCount }; + // Combine decoded splat arrays + const finalSplats = splatArrays[0]; + for (let i = 1; i < splatArrays.length; i++) { + finalSplats.concat(splatArrays[i]); + } + + const spzBytes = finalSplats.encode_to_spz(maxSh ?? 3, fractionalBits); + + return { fileBytes: spzBytes, clippedCount: 0 }; +} + +export function writeSpz( + packedSplats: PackedSplats, + maxSh?: number, + fractionalBits?: number, +) { + if (!packedSplats.packedArray) { + throw new Error(""); + } + const gsplats = packedsplats_to_gsplatarray( + packedSplats.numSplats, + packedSplats.packedArray, + packedSplats.extra, + packedSplats.splatEncoding, + ); + const spzBytes = gsplats.encode_to_spz(maxSh ?? 3, fractionalBits ?? 12); + return { fileBytes: spzBytes }; } diff --git a/src/worker.ts b/src/worker.ts index 1801a758..a6825d4e 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -96,35 +96,6 @@ function sortSplats32({ return { activeSplats, readback, ordering }; } -async function fetchRange({ - url, - requestHeader, - withCredentials, - offset, - bytes, -}: { - url: string; - requestHeader?: Record; - withCredentials?: string; - offset?: number; - bytes?: number; -}): Promise { - const request = new Request(url, { - headers: requestHeader ? new Headers(requestHeader) : undefined, - credentials: withCredentials ? "include" : "same-origin", - }); - if (offset !== undefined && bytes !== undefined) { - request.headers.set("Range", `bytes=${offset}-${offset + bytes - 1}`); - } - const response = await fetch(request); - if (!response.ok || !response.body) { - throw new Error( - `Failed to fetch "${url}": ${response.status} ${response.statusText}`, - ); - } - return new Uint8Array(await response.arrayBuffer()); -} - async function decodeBytesUrl({ decoder, fileBytes,