Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/strands/strands_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
const nodeData = DAG.createNodeData({
nodeType: NodeType.STATEMENT,
statementType: StatementType.EARLY_RETURN,
dependsOn: [valueNode.id]
dependsOn: value !== undefined ? [valueNode.id] : []
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind elaborating on what these changes are there to handle? Is there anything we should add more test cases for?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes fix void return types in compute shaders. Without them, using `return;` in a compute hook would crash with "Missing dataType". Most compute shaders return void (side effects only), so the auto-spread wouldn't work without this fix.

For tests - should I add cases for void hooks with early returns? The main compute functionality already has test coverage.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, got it. Right, let's add a test for early returns, since this wasn't a case covered by any tests before. Thanks!

Copy link
Copy Markdown
Author

@aashu2006 aashu2006 Apr 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the test cases for void compute hooks with early returns. Both tests are passing.
Thanks!

});
const earlyReturnID = DAG.getOrCreateNode(dag, nodeData);
CFG.recordInBasicBlock(cfg, cfg.currentBlock, earlyReturnID);
Expand Down Expand Up @@ -786,17 +786,21 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
return newStruct.id;
}
}
else if (!expectedReturnType.dataType || expectedReturnType.typeName?.trim() === 'void') {
return null;
}
else /*if(isNativeType(expectedReturnType.typeName))*/ {
if (!expectedReturnType.dataType) {
throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`);
}
const expectedTypeInfo = expectedReturnType.dataType;
return enforceReturnTypeMatch(strandsContext, expectedTypeInfo, retNode, hookType.name);
}
}
for (const { valueNode, earlyReturnID } of hook.earlyReturns) {
const id = handleRetVal(valueNode);
dag.dependsOn[earlyReturnID] = [id];
if (id !== null) {
dag.dependsOn[earlyReturnID] = [id];
} else {
dag.dependsOn[earlyReturnID] = [];
}
}
rootNodeID = userReturned ? handleRetVal(userReturned) : undefined;
const fullHookName = `${hookType.returnType.typeName} ${hookType.name}`;
Expand Down
5 changes: 1 addition & 4 deletions src/strands/strands_codegen.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,9 @@ export function generateShaderCode(strandsContext) {
let returnType;
if (hookType.returnType.properties) {
returnType = structType(hookType.returnType);
} else if (hookType.returnType.typeName === 'void') {
} else if (!hookType.returnType.dataType || hookType.returnType.typeName?.trim() === 'void') {
returnType = null;
} else {
if (!hookType.returnType.dataType) {
throw new Error(`Missing dataType for return type ${hookType.returnType.typeName}`);
}
returnType = hookType.returnType.dataType;
}

Expand Down
38 changes: 34 additions & 4 deletions src/webgpu/p5.RendererWebGPU.js
Original file line number Diff line number Diff line change
Expand Up @@ -3813,10 +3813,40 @@ ${hookUniformFields}}
const WORKGROUP_SIZE_Y = 8;
const WORKGROUP_SIZE_Z = 1;

// Calculate number of workgroups needed
const workgroupCountX = Math.ceil(x / WORKGROUP_SIZE_X);
const workgroupCountY = Math.ceil(y / WORKGROUP_SIZE_Y);
const workgroupCountZ = Math.ceil(z / WORKGROUP_SIZE_Z);
// auto spreading: if any dimension is too large or for performance optimization,
// spread total iteration count across dimensions
const totalIterations = x * y * z;
const MAX_THREADS_PER_DIM = 65535 * 8;

let px = x;
let py = y;
let pz = z;

// we spread if we exceed GPU limits OR if it involves a large 1D dispatch
const exceedsLimits = x > MAX_THREADS_PER_DIM || y > MAX_THREADS_PER_DIM || z > MAX_THREADS_PER_DIM;
const isLarge1D = totalIterations > 1024 && y === 1 && z === 1;

if (exceedsLimits || isLarge1D) {
// Always use 2D square spreading (√N × √N).
// Benchmarks showed 2D square equals or outperforms 3D cube at every
// scale tested, with simpler index reconstruction in the shader.
px = Math.ceil(Math.sqrt(totalIterations));
py = Math.ceil(totalIterations / px);
pz = 1;

if (p5.debug || exceedsLimits) {
console.warn(
`p5.js: Compute dispatch (${x}, ${y}, ${z}) auto-spread to (${px}, ${py}, 1) ` +
`to ${exceedsLimits ? 'stay within GPU limits' : 'optimize performance'}.`
);
}
}

shader.setUniform('uPhysicalCount', [px, py, pz]);

const workgroupCountX = Math.ceil(px / WORKGROUP_SIZE_X);
const workgroupCountY = Math.ceil(py / WORKGROUP_SIZE_Y);
const workgroupCountZ = Math.ceil(pz / WORKGROUP_SIZE_Z);

const commandEncoder = this.device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
Expand Down
16 changes: 10 additions & 6 deletions src/webgpu/shaders/compute.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export const baseComputeShader = `
struct ComputeUniforms {
uTotalCount: vec3<i32>,
uPhysicalCount: vec3<i32>,
}
@group(0) @binding(0) var<uniform> uniforms: ComputeUniforms;

Expand All @@ -11,16 +12,19 @@ fn main(
@builtin(workgroup_id) workgroupId: vec3<u32>,
@builtin(local_invocation_index) localIndex: u32
) {
var index = vec3<i32>(globalId);
let totalIterations = u32(uniforms.uTotalCount.x) * u32(uniforms.uTotalCount.y) * u32(uniforms.uTotalCount.z);
let physicalId = globalId.x + globalId.y * (u32(uniforms.uPhysicalCount.x)) + globalId.z * (u32(uniforms.uPhysicalCount.x) * u32(uniforms.uPhysicalCount.y));

if (
index.x >= uniforms.uTotalCount.x ||
index.y >= uniforms.uTotalCount.y ||
index.z >= uniforms.uTotalCount.z
) {
if (physicalId >= totalIterations) {
return;
}

var index = vec3<i32>(0);
index.x = i32(physicalId % u32(uniforms.uTotalCount.x));
let remainingY = physicalId / u32(uniforms.uTotalCount.x);
index.y = i32(remainingY % u32(uniforms.uTotalCount.y));
index.z = i32(remainingY / u32(uniforms.uTotalCount.y));

HOOK_iteration(index);
}
`;
10 changes: 7 additions & 3 deletions src/webgpu/strands_wgslBackend.js
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,13 @@ export const wgslBackend = {
// Generate just a semicolon (unless suppressed)
generationContext.write(semicolon);
} else if (node.statementType === StatementType.EARLY_RETURN) {
const exprNodeID = node.dependsOn[0];
const expr = this.generateExpression(generationContext, dag, exprNodeID);
generationContext.write(`return ${expr}${semicolon}`);
if (node.dependsOn && node.dependsOn.length > 0) {
const exprNodeID = node.dependsOn[0];
const expr = this.generateExpression(generationContext, dag, exprNodeID);
generationContext.write(`return ${expr}${semicolon}`);
} else {
generationContext.write(`return${semicolon}`);
}
}
},
generateAssignment(generationContext, dag, nodeID) {
Expand Down
41 changes: 41 additions & 0 deletions test/unit/webgpu/p5.Shader.js
Original file line number Diff line number Diff line change
Expand Up @@ -1228,5 +1228,46 @@ suite('WebGPU p5.Shader', function() {
});
}
});

// Tests for compute hooks that have no return value and contain a bare
// `return;`. Both tests assert only that dispatching via myp5.compute()
// does not throw.
suite('compute shaders', () => {
test('handle early return in void compute hook', async () => {
await myp5.createCanvas(5, 5, myp5.WEBGPU);

// This test verifies that buildComputeShader and p5.compute
// correctly handle void hooks with early returns without crashing
// the strands compiler or hitting type errors.
// Both the shader build and the dispatch run inside the expect() wrapper,
// so a throw from either one fails the assertion.
expect(() => {
const computeShader = myp5.buildComputeShader(() => {
const id = myp5.index.x;
if (id > 10) {
return; // Early return in void hook
}
}, { myp5 });

myp5.compute(computeShader, 1);
}).not.toThrow();
});

// NOTE(review): the test name says the early return "stops execution", but
// the only assertion below is that compute() does not throw. The storage
// buffer is never read back, so nothing here checks that buf[0] ends up as
// 1.0 rather than 2.0 — consider adding a read-back assertion.
test('early return in void compute hook stops execution', async () => {
await myp5.createCanvas(5, 5, myp5.WEBGPU);
const data = myp5.createStorage([0]);

// Unlike the previous test, the shader is built outside the expect()
// wrapper, so a build-time throw surfaces as a plain test error.
const computeShader = myp5.buildComputeShader(() => {
const buf = myp5.uniformStorage();
const id = myp5.index.x;
if (id == 0) {
buf[0] = 1.0;
return;
buf[0] = 2.0; // Should not execute
}
}, { myp5 });

// Bind the storage created above to the shader's `buf` uniform.
computeShader.setUniform('buf', data);

expect(() => {
myp5.compute(computeShader, 1);
}).not.toThrow();
});
});
});
});