From 82fa2369df2e2148c9a3e2ac858cbdd63b5d3f59 Mon Sep 17 00:00:00 2001 From: Marco Fugaro Date: Fri, 20 Mar 2026 05:37:04 +0100 Subject: [PATCH] WebGPURenderer: Add compute shader bounds check (#33186) Co-authored-by: sunag --- examples/webgpu_compute_cloth.html | 18 +-- examples/webgpu_compute_particles_fluid.html | 31 +---- src/nodes/core/IndexNode.js | 3 +- src/nodes/gpgpu/BarrierNode.js | 4 + src/nodes/gpgpu/ComputeNode.js | 113 +++++++++++------- src/renderers/common/ComputePipeline.js | 2 +- src/renderers/webgpu/WebGPUBackend.js | 4 +- src/renderers/webgpu/nodes/WGSLNodeBuilder.js | 8 ++ 8 files changed, 91 insertions(+), 92 deletions(-) diff --git a/examples/webgpu_compute_cloth.html b/examples/webgpu_compute_cloth.html index 47d4f51a58eba0..8d8c5ef36b5ccc 100644 --- a/examples/webgpu_compute_cloth.html +++ b/examples/webgpu_compute_cloth.html @@ -33,7 +33,7 @@ import * as THREE from 'three/webgpu'; - import { Fn, If, Return, instancedArray, instanceIndex, uniform, select, attribute, uint, Loop, float, transformNormalToView, cross, triNoise3D, time } from 'three/tsl'; + import { Fn, If, Return, instancedArray, instanceIndex, uniform, select, attribute, Loop, float, transformNormalToView, cross, triNoise3D, time } from 'three/tsl'; import { Inspector } from 'three/addons/inspector/Inspector.js'; @@ -307,14 +307,6 @@ // This shader computes a force for each spring, depending on the distance between the two vertices connected by that spring and the targeted rest length computeSpringForces = Fn( () => { - If( instanceIndex.greaterThanEqual( uint( springCount ) ), () => { - - // compute Shaders are executed in groups of 64, so instanceIndex might be bigger than the amount of springs. - // in that case, return. - Return(); - - } ); - const vertexIds = springVertexIdBuffer.element( instanceIndex ); const restLength = springRestLengthBuffer.element( instanceIndex ); @@ -335,14 +327,6 @@ // In the end it adds the force to the vertex' position. computeVertexForces = Fn( () => { - If( instanceIndex.greaterThanEqual( uint( vertexCount ) ), () => { - - // compute Shaders are executed in groups of 64, so instanceIndex might be bigger than the amount of vertices. - // in that case, return. - Return(); - - } ); - const params = vertexParamsBuffer.element( instanceIndex ).toVar(); const isFixed = params.x; const springCount = params.y; diff --git a/examples/webgpu_compute_particles_fluid.html b/examples/webgpu_compute_particles_fluid.html index ee02b41800c7be..4ff83654da73e8 100644 --- a/examples/webgpu_compute_particles_fluid.html +++ b/examples/webgpu_compute_particles_fluid.html @@ -33,7 +33,7 @@ import * as THREE from 'three/webgpu'; - import { Fn, If, Return, instancedArray, instanceIndex, uniform, attribute, uint, float, clamp, struct, atomicStore, int, ivec3, array, vec3, atomicAdd, Loop, atomicLoad, max, pow, mat3, vec4, cross, step, storage } from 'three/tsl'; + import { Fn, If, Return, instancedArray, instanceIndex, uniform, attribute, float, clamp, struct, atomicStore, int, ivec3, array, vec3, atomicAdd, Loop, atomicLoad, max, pow, mat3, vec4, cross, step, storage } from 'three/tsl'; import { Inspector } from 'three/addons/inspector/Inspector.js'; @@ -132,6 +132,9 @@ gui.add( params, 'particleCount', 4096, maxParticles, 4096 ).onChange( value => { particleMesh.count = value; + p2g1Kernel.count = value; + p2g2Kernel.count = value; + g2pKernel.count = value; particleCountUniform.value = value; } ); @@ -219,12 +222,6 @@ const cellCount = gridSize.x * gridSize.y * gridSize.z; clearGridKernel = Fn( () => { - If( instanceIndex.greaterThanEqual( uint( cellCount ) ), () => { - - Return(); - - } ); - atomicStore( cellBuffer.element( instanceIndex ).get( 'x' ), 0 ); atomicStore( cellBuffer.element( instanceIndex ).get( 'y' ), 0 ); atomicStore( cellBuffer.element( instanceIndex ).get( 'z' ), 0 ); @@ -234,11 +231,6 @@ p2g1Kernel = Fn( () => { - If( instanceIndex.greaterThanEqual( particleCountUniform ), () => { - - Return(); - - } ); const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toConst( 'particlePosition' ); const particleVelocity = particleBuffer.element( instanceIndex ).get( 'velocity' ).toConst( 'particleVelocity' ); const C = particleBuffer.element( instanceIndex ).get( 'C' ).toConst( 'C' ); @@ -282,11 +274,6 @@ p2g2Kernel = Fn( () => { - If( instanceIndex.greaterThanEqual( particleCountUniform ), () => { - - Return(); - - } ); const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toConst( 'particlePosition' ); const gridPosition = particlePosition.mul( gridSizeUniform ).toVar(); @@ -353,11 +340,6 @@ updateGridKernel = Fn( () => { - If( instanceIndex.greaterThanEqual( uint( cellCount ) ), () => { - - Return(); - - } ); const cell = cellBuffer.element( instanceIndex ); const mass = decodeFixedPoint( atomicLoad( cell.get( 'mass' ) ) ).toConst(); If( mass.lessThanEqual( 0 ), () => { @@ -412,11 +394,6 @@ g2pKernel = Fn( () => { - If( instanceIndex.greaterThanEqual( particleCountUniform ), () => { - - Return(); - - } ); const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toVar( 'particlePosition' ); const gridPosition = particlePosition.mul( gridSizeUniform ).toVar(); const particleVelocity = vec3( 0 ).toVar(); diff --git a/src/nodes/core/IndexNode.js b/src/nodes/core/IndexNode.js index 5908cc694d75ea..38c4ed15962b26 100644 --- a/src/nodes/core/IndexNode.js +++ b/src/nodes/core/IndexNode.js @@ -1,5 +1,6 @@ import Node from './Node.js'; -import { nodeImmutable, varying } from '../tsl/TSLBase.js'; +import { nodeImmutable } from '../tsl/TSLCore.js'; +import { varying } from './VaryingNode.js'; /** * This class represents shader indices of different types. The following predefined node diff --git a/src/nodes/gpgpu/BarrierNode.js b/src/nodes/gpgpu/BarrierNode.js index f8e70057726893..56c1ee0cdff872 100644 --- a/src/nodes/gpgpu/BarrierNode.js +++ b/src/nodes/gpgpu/BarrierNode.js @@ -21,6 +21,8 @@ class BarrierNode extends Node { this.scope = scope; + this.isBarrierNode = true; + } generate( builder ) { @@ -28,6 +30,8 @@ class BarrierNode extends Node { const { scope } = this; const { renderer } = builder; + builder.allowEarlyReturns = false; + if ( renderer.backend.isWebGLBackend === true ) { builder.addFlowCode( `\t// ${scope}Barrier \n` ); diff --git a/src/nodes/gpgpu/ComputeNode.js b/src/nodes/gpgpu/ComputeNode.js index dcc78ae70b1ac8..5c19a94454a945 100644 --- a/src/nodes/gpgpu/ComputeNode.js +++ b/src/nodes/gpgpu/ComputeNode.js @@ -1,11 +1,13 @@ import Node from '../core/Node.js'; +import { instanceIndex } from '../core/IndexNode.js'; import StackTrace from '../core/StackTrace.js'; +import { uniform } from '../core/UniformNode.js'; import { NodeUpdateType } from '../core/constants.js'; import { addMethodChaining, nodeObject } from '../tsl/TSLCore.js'; import { warn, error } from '../../utils.js'; /** - * TODO + * Represents a compute shader node. * * @augments Node */ @@ -20,8 +22,8 @@ class ComputeNode extends Node { /** * Constructs a new compute node. * - * @param {Node} computeNode - TODO - * @param {Array} workgroupSize - TODO. + * @param {Node} computeNode - The node that defines the compute shader logic. + * @param {Array} workgroupSize - An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution. */ constructor( computeNode, workgroupSize ) { @@ -37,15 +39,14 @@ class ComputeNode extends Node { this.isComputeNode = true; /** - * TODO + * The node that defines the compute shader logic. * * @type {Node} */ this.computeNode = computeNode; - /** - * TODO + * An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution. * * @type {Array} * @default [ 64 ] @@ -53,14 +54,23 @@ class ComputeNode extends Node { this.workgroupSize = workgroupSize; /** - * TODO + * The total number of threads (invocations) to execute. If it is a number, it will be used + * to automatically generate bounds checking against `instanceIndex`. * * @type {number|Array} */ this.count = null; /** - * TODO + * The dispatch size for workgroups on X, Y, and Z axes. + * Used directly if `count` is not provided. + * + * @type {number|Array} + */ + this.dispatchSize = null; + + /** + * The version of the node. * * @type {number} */ @@ -84,36 +94,19 @@ class ComputeNode extends Node { this.updateBeforeType = NodeUpdateType.OBJECT; /** - * TODO + * A callback executed when the compute node finishes initialization. * * @type {?Function} */ this.onInitFunction = null; - } - - /** - * TODO - * - * @param {number|Array} count - Array with [ x, y, z ] values for dispatch or a single number for the count - * @return {ComputeNode} - */ - setCount( count ) { - - this.count = count; - - return this; - - } - - /** - * TODO - * - * @return {number|Array} - */ - getCount() { - - return this.count; + /** + * A uniform node holding the dispatch count for bounds checking. + * Created automatically when `count` is a number. + * + * @type {?UniformNode} + */ + this.countNode = null; } @@ -156,9 +149,9 @@ class ComputeNode extends Node { } /** - * TODO + * Sets the callback to run during initialization. * - * @param {Function} callback - TODO. + * @param {Function} callback - The callback function. * @return {ComputeNode} A reference to this node. */ onInit( callback ) { @@ -182,6 +175,12 @@ class ComputeNode extends Node { setup( builder ) { + if ( this.count !== null && this.countNode === null ) { + + this.countNode = uniform( this.count, 'uint' ).onObjectUpdate( () => this.count ); + + } + const result = this.computeNode.build( builder ); if ( result ) { @@ -211,6 +210,16 @@ class ComputeNode extends Node { } + if ( this.count !== null && builder.allowEarlyReturns === true ) { + + const countSnippet = this.countNode.build( builder, 'uint' ); + const indexSnippet = instanceIndex.build( builder, 'uint' ); + + builder.flow.code = `${ builder.tab }if ( ${ indexSnippet } >= ${ countSnippet } ) { return; }\n\n${ builder.flow.code }`; + + } + + } else { const properties = builder.getNodeProperties( this ); @@ -235,9 +244,9 @@ export default ComputeNode; * * @tsl * @function - * @param {Node} node - TODO - * @param {Array} [workgroupSize=[64]] - TODO. - * @returns {AtomicFunctionNode} + * @param {Node} node - The TSL logic for the compute shader. + * @param {Array} [workgroupSize=[64]] - The workgroup size. + * @returns {ComputeNode} */ export const computeKernel = ( node, workgroupSize = [ 64 ] ) => { @@ -274,12 +283,28 @@ export const computeKernel = ( node, workgroupSize = [ 64 ] ) => { * * @tsl * @function - * @param {Node} node - TODO - * @param {number|Array} count - TODO. - * @param {Array} [workgroupSize=[64]] - TODO. - * @returns {AtomicFunctionNode} - */ -export const compute = ( node, count, workgroupSize ) => computeKernel( node, workgroupSize ).setCount( count ); + * @param {Node} node - The TSL logic for the compute shader. + * @param {number|Array} count - The compute count or dispatch size. + * @param {Array} [workgroupSize=[64]] - The workgroup size. + * @returns {ComputeNode} +, */ +export const compute = ( node, count, workgroupSize ) => { + + const computeNode = computeKernel( node, workgroupSize ); + + if ( typeof count === 'number' ) { + + computeNode.count = count; + + } else { + + computeNode.dispatchSize = count; + + } + + return computeNode; + +}; addMethodChaining( 'compute', compute ); addMethodChaining( 'computeKernel', computeKernel ); diff --git a/src/renderers/common/ComputePipeline.js b/src/renderers/common/ComputePipeline.js index 8dfe198b859db7..1d22585a4227b2 100644 --- a/src/renderers/common/ComputePipeline.js +++ b/src/renderers/common/ComputePipeline.js @@ -9,7 +9,7 @@ import Pipeline from './Pipeline.js'; class ComputePipeline extends Pipeline { /** - * Constructs a new render pipeline. + * Constructs a new compute pipeline. * * @param {string} cacheKey - The pipeline's cache key. * @param {ProgrammableStage} computeProgram - The pipeline's compute shader. diff --git a/src/renderers/webgpu/WebGPUBackend.js b/src/renderers/webgpu/WebGPUBackend.js index 49a5cdd9401abe..7ee695b54b008c 100644 --- a/src/renderers/webgpu/WebGPUBackend.js +++ b/src/renderers/webgpu/WebGPUBackend.js @@ -1425,13 +1425,13 @@ class WebGPUBackend extends Backend { if ( dispatchSize === null ) { - dispatchSize = computeNode.count; + dispatchSize = computeNode.dispatchSize || computeNode.count; } // When the dispatchSize is set with a StorageBuffer from the GPU. - if ( dispatchSize && typeof dispatchSize === 'object' && dispatchSize.isIndirectStorageBufferAttribute ) { + if ( dispatchSize && dispatchSize.isIndirectStorageBufferAttribute ) { const dispatchBuffer = this.get( dispatchSize ).buffer; diff --git a/src/renderers/webgpu/nodes/WGSLNodeBuilder.js b/src/renderers/webgpu/nodes/WGSLNodeBuilder.js index 936d0a34caf2ca..d33cc95138c082 100644 --- a/src/renderers/webgpu/nodes/WGSLNodeBuilder.js +++ b/src/renderers/webgpu/nodes/WGSLNodeBuilder.js @@ -230,6 +230,14 @@ class WGSLNodeBuilder extends NodeBuilder { */ this.scopedArrays = new Map(); + /** + * A flag that indicates that early returns are allowed. + * + * @type {boolean} + * @default true + */ + this.allowEarlyReturns = true; + } /**