diff --git a/examples/webgpu_compute_cloth.html b/examples/webgpu_compute_cloth.html
index 47d4f51a58eba0..8d8c5ef36b5ccc 100644
--- a/examples/webgpu_compute_cloth.html
+++ b/examples/webgpu_compute_cloth.html
@@ -33,7 +33,7 @@
import * as THREE from 'three/webgpu';
- import { Fn, If, Return, instancedArray, instanceIndex, uniform, select, attribute, uint, Loop, float, transformNormalToView, cross, triNoise3D, time } from 'three/tsl';
+ import { Fn, If, Return, instancedArray, instanceIndex, uniform, select, attribute, Loop, float, transformNormalToView, cross, triNoise3D, time } from 'three/tsl';
import { Inspector } from 'three/addons/inspector/Inspector.js';
@@ -307,14 +307,6 @@
// This shader computes a force for each spring, depending on the distance between the two vertices connected by that spring and the targeted rest length
computeSpringForces = Fn( () => {
- If( instanceIndex.greaterThanEqual( uint( springCount ) ), () => {
-
- // compute Shaders are executed in groups of 64, so instanceIndex might be bigger than the amount of springs.
- // in that case, return.
- Return();
-
- } );
-
const vertexIds = springVertexIdBuffer.element( instanceIndex );
const restLength = springRestLengthBuffer.element( instanceIndex );
@@ -335,14 +327,6 @@
// In the end it adds the force to the vertex' position.
computeVertexForces = Fn( () => {
- If( instanceIndex.greaterThanEqual( uint( vertexCount ) ), () => {
-
- // compute Shaders are executed in groups of 64, so instanceIndex might be bigger than the amount of vertices.
- // in that case, return.
- Return();
-
- } );
-
const params = vertexParamsBuffer.element( instanceIndex ).toVar();
const isFixed = params.x;
const springCount = params.y;
diff --git a/examples/webgpu_compute_particles_fluid.html b/examples/webgpu_compute_particles_fluid.html
index ee02b41800c7be..4ff83654da73e8 100644
--- a/examples/webgpu_compute_particles_fluid.html
+++ b/examples/webgpu_compute_particles_fluid.html
@@ -33,7 +33,7 @@
import * as THREE from 'three/webgpu';
- import { Fn, If, Return, instancedArray, instanceIndex, uniform, attribute, uint, float, clamp, struct, atomicStore, int, ivec3, array, vec3, atomicAdd, Loop, atomicLoad, max, pow, mat3, vec4, cross, step, storage } from 'three/tsl';
+ import { Fn, If, Return, instancedArray, instanceIndex, uniform, attribute, float, clamp, struct, atomicStore, int, ivec3, array, vec3, atomicAdd, Loop, atomicLoad, max, pow, mat3, vec4, cross, step, storage } from 'three/tsl';
import { Inspector } from 'three/addons/inspector/Inspector.js';
@@ -132,6 +132,9 @@
gui.add( params, 'particleCount', 4096, maxParticles, 4096 ).onChange( value => {
particleMesh.count = value;
+ p2g1Kernel.count = value;
+ p2g2Kernel.count = value;
+ g2pKernel.count = value;
particleCountUniform.value = value;
} );
@@ -219,12 +222,6 @@
const cellCount = gridSize.x * gridSize.y * gridSize.z;
clearGridKernel = Fn( () => {
- If( instanceIndex.greaterThanEqual( uint( cellCount ) ), () => {
-
- Return();
-
- } );
-
atomicStore( cellBuffer.element( instanceIndex ).get( 'x' ), 0 );
atomicStore( cellBuffer.element( instanceIndex ).get( 'y' ), 0 );
atomicStore( cellBuffer.element( instanceIndex ).get( 'z' ), 0 );
@@ -234,11 +231,6 @@
p2g1Kernel = Fn( () => {
- If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {
-
- Return();
-
- } );
const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toConst( 'particlePosition' );
const particleVelocity = particleBuffer.element( instanceIndex ).get( 'velocity' ).toConst( 'particleVelocity' );
const C = particleBuffer.element( instanceIndex ).get( 'C' ).toConst( 'C' );
@@ -282,11 +274,6 @@
p2g2Kernel = Fn( () => {
- If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {
-
- Return();
-
- } );
const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toConst( 'particlePosition' );
const gridPosition = particlePosition.mul( gridSizeUniform ).toVar();
@@ -353,11 +340,6 @@
updateGridKernel = Fn( () => {
- If( instanceIndex.greaterThanEqual( uint( cellCount ) ), () => {
-
- Return();
-
- } );
const cell = cellBuffer.element( instanceIndex );
const mass = decodeFixedPoint( atomicLoad( cell.get( 'mass' ) ) ).toConst();
If( mass.lessThanEqual( 0 ), () => {
@@ -412,11 +394,6 @@
g2pKernel = Fn( () => {
- If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {
-
- Return();
-
- } );
const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toVar( 'particlePosition' );
const gridPosition = particlePosition.mul( gridSizeUniform ).toVar();
const particleVelocity = vec3( 0 ).toVar();
diff --git a/src/nodes/core/IndexNode.js b/src/nodes/core/IndexNode.js
index 5908cc694d75ea..38c4ed15962b26 100644
--- a/src/nodes/core/IndexNode.js
+++ b/src/nodes/core/IndexNode.js
@@ -1,5 +1,6 @@
import Node from './Node.js';
-import { nodeImmutable, varying } from '../tsl/TSLBase.js';
+import { nodeImmutable } from '../tsl/TSLCore.js';
+import { varying } from './VaryingNode.js';
/**
* This class represents shader indices of different types. The following predefined node
diff --git a/src/nodes/gpgpu/BarrierNode.js b/src/nodes/gpgpu/BarrierNode.js
index f8e70057726893..56c1ee0cdff872 100644
--- a/src/nodes/gpgpu/BarrierNode.js
+++ b/src/nodes/gpgpu/BarrierNode.js
@@ -21,6 +21,8 @@ class BarrierNode extends Node {
this.scope = scope;
+ this.isBarrierNode = true;
+
}
generate( builder ) {
@@ -28,6 +30,8 @@ class BarrierNode extends Node {
const { scope } = this;
const { renderer } = builder;
+ builder.allowEarlyReturns = false;
+
if ( renderer.backend.isWebGLBackend === true ) {
builder.addFlowCode( `\t// ${scope}Barrier \n` );
diff --git a/src/nodes/gpgpu/ComputeNode.js b/src/nodes/gpgpu/ComputeNode.js
index dcc78ae70b1ac8..5c19a94454a945 100644
--- a/src/nodes/gpgpu/ComputeNode.js
+++ b/src/nodes/gpgpu/ComputeNode.js
@@ -1,11 +1,13 @@
import Node from '../core/Node.js';
+import { instanceIndex } from '../core/IndexNode.js';
import StackTrace from '../core/StackTrace.js';
+import { uniform } from '../core/UniformNode.js';
import { NodeUpdateType } from '../core/constants.js';
import { addMethodChaining, nodeObject } from '../tsl/TSLCore.js';
import { warn, error } from '../../utils.js';
/**
- * TODO
+ * Represents a compute shader node.
*
* @augments Node
*/
@@ -20,8 +22,8 @@ class ComputeNode extends Node {
/**
* Constructs a new compute node.
*
- * @param {Node} computeNode - TODO
- * @param {Array} workgroupSize - TODO.
+ * @param {Node} computeNode - The node that defines the compute shader logic.
+ * @param {Array} workgroupSize - An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution.
*/
constructor( computeNode, workgroupSize ) {
@@ -37,15 +39,14 @@ class ComputeNode extends Node {
this.isComputeNode = true;
/**
- * TODO
+ * The node that defines the compute shader logic.
*
* @type {Node}
*/
this.computeNode = computeNode;
-
/**
- * TODO
+ * An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution.
*
* @type {Array}
* @default [ 64 ]
@@ -53,14 +54,23 @@ class ComputeNode extends Node {
this.workgroupSize = workgroupSize;
/**
- * TODO
+ * The total number of threads (invocations) to execute. If it is a number, it will be used
+ * to automatically generate bounds checking against `instanceIndex`.
*
* @type {number|Array}
*/
this.count = null;
/**
- * TODO
+ * The dispatch size for workgroups on X, Y, and Z axes.
+ * Used directly if `count` is not provided.
+ *
+ * @type {number|Array}
+ */
+ this.dispatchSize = null;
+
+ /**
+ * The version of the node.
*
* @type {number}
*/
@@ -84,36 +94,19 @@ class ComputeNode extends Node {
this.updateBeforeType = NodeUpdateType.OBJECT;
/**
- * TODO
+ * A callback executed when the compute node finishes initialization.
*
* @type {?Function}
*/
this.onInitFunction = null;
- }
-
- /**
- * TODO
- *
- * @param {number|Array} count - Array with [ x, y, z ] values for dispatch or a single number for the count
- * @return {ComputeNode}
- */
- setCount( count ) {
-
- this.count = count;
-
- return this;
-
- }
-
- /**
- * TODO
- *
- * @return {number|Array}
- */
- getCount() {
-
- return this.count;
+ /**
+ * A uniform node holding the dispatch count for bounds checking.
+ * Created automatically when `count` is a number.
+ *
+ * @type {?UniformNode}
+ */
+ this.countNode = null;
}
@@ -156,9 +149,9 @@ class ComputeNode extends Node {
}
/**
- * TODO
+ * Sets the callback to run during initialization.
*
- * @param {Function} callback - TODO.
+ * @param {Function} callback - The callback function.
* @return {ComputeNode} A reference to this node.
*/
onInit( callback ) {
@@ -182,6 +175,12 @@ class ComputeNode extends Node {
setup( builder ) {
+ if ( this.count !== null && this.countNode === null ) {
+
+ this.countNode = uniform( this.count, 'uint' ).onObjectUpdate( () => this.count );
+
+ }
+
const result = this.computeNode.build( builder );
if ( result ) {
@@ -211,6 +210,16 @@ class ComputeNode extends Node {
}
+ if ( this.count !== null && builder.allowEarlyReturns === true ) {
+
+ const countSnippet = this.countNode.build( builder, 'uint' );
+ const indexSnippet = instanceIndex.build( builder, 'uint' );
+
+ builder.flow.code = `${ builder.tab }if ( ${ indexSnippet } >= ${ countSnippet } ) { return; }\n\n${ builder.flow.code }`;
+
+ }
+
+
} else {
const properties = builder.getNodeProperties( this );
@@ -235,9 +244,9 @@ export default ComputeNode;
*
* @tsl
* @function
- * @param {Node} node - TODO
- * @param {Array} [workgroupSize=[64]] - TODO.
- * @returns {AtomicFunctionNode}
+ * @param {Node} node - The TSL logic for the compute shader.
+ * @param {Array} [workgroupSize=[64]] - The workgroup size.
+ * @returns {ComputeNode}
*/
export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {
@@ -274,12 +283,28 @@ export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {
*
* @tsl
* @function
- * @param {Node} node - TODO
- * @param {number|Array} count - TODO.
- * @param {Array} [workgroupSize=[64]] - TODO.
- * @returns {AtomicFunctionNode}
- */
-export const compute = ( node, count, workgroupSize ) => computeKernel( node, workgroupSize ).setCount( count );
+ * @param {Node} node - The TSL logic for the compute shader.
+ * @param {number|Array} count - The compute count or dispatch size.
+ * @param {Array} [workgroupSize=[64]] - The workgroup size.
+ * @returns {ComputeNode}
+, */
+export const compute = ( node, count, workgroupSize ) => {
+
+ const computeNode = computeKernel( node, workgroupSize );
+
+ if ( typeof count === 'number' ) {
+
+ computeNode.count = count;
+
+ } else {
+
+ computeNode.dispatchSize = count;
+
+ }
+
+ return computeNode;
+
+};
addMethodChaining( 'compute', compute );
addMethodChaining( 'computeKernel', computeKernel );
diff --git a/src/renderers/common/ComputePipeline.js b/src/renderers/common/ComputePipeline.js
index 8dfe198b859db7..1d22585a4227b2 100644
--- a/src/renderers/common/ComputePipeline.js
+++ b/src/renderers/common/ComputePipeline.js
@@ -9,7 +9,7 @@ import Pipeline from './Pipeline.js';
class ComputePipeline extends Pipeline {
/**
- * Constructs a new render pipeline.
+ * Constructs a new compute pipeline.
*
* @param {string} cacheKey - The pipeline's cache key.
* @param {ProgrammableStage} computeProgram - The pipeline's compute shader.
diff --git a/src/renderers/webgpu/WebGPUBackend.js b/src/renderers/webgpu/WebGPUBackend.js
index 49a5cdd9401abe..7ee695b54b008c 100644
--- a/src/renderers/webgpu/WebGPUBackend.js
+++ b/src/renderers/webgpu/WebGPUBackend.js
@@ -1425,13 +1425,13 @@ class WebGPUBackend extends Backend {
if ( dispatchSize === null ) {
- dispatchSize = computeNode.count;
+ dispatchSize = computeNode.dispatchSize || computeNode.count;
}
// When the dispatchSize is set with a StorageBuffer from the GPU.
- if ( dispatchSize && typeof dispatchSize === 'object' && dispatchSize.isIndirectStorageBufferAttribute ) {
+ if ( dispatchSize && dispatchSize.isIndirectStorageBufferAttribute ) {
const dispatchBuffer = this.get( dispatchSize ).buffer;
diff --git a/src/renderers/webgpu/nodes/WGSLNodeBuilder.js b/src/renderers/webgpu/nodes/WGSLNodeBuilder.js
index 936d0a34caf2ca..d33cc95138c082 100644
--- a/src/renderers/webgpu/nodes/WGSLNodeBuilder.js
+++ b/src/renderers/webgpu/nodes/WGSLNodeBuilder.js
@@ -230,6 +230,14 @@ class WGSLNodeBuilder extends NodeBuilder {
*/
this.scopedArrays = new Map();
+ /**
+ * A flag that indicates that early returns are allowed.
+ *
+ * @type {boolean}
+ * @default true
+ */
+ this.allowEarlyReturns = true;
+
}
/**