Skip to content

Commit 7e75e52

Browse files
authored
Validate Resource commands and properly account for element size (#99)
1 parent 41327a4 commit 7e75e52

File tree

3 files changed

+126
-26
lines changed

3 files changed

+126
-26
lines changed

compiler.ts

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,54 @@ type BindingDescriptor = {
7373

7474
export type Bindings = Map<string, GPUBindGroupLayoutEntry>;
7575

76+
export type ReflectionBinding = {
77+
"kind": "uniform",
78+
"offset": number,
79+
"size": number,
80+
} | {
81+
"kind": "descriptorTableSlot",
82+
"index": number,
83+
};
84+
85+
export type ReflectionType = {
86+
"kind": "struct",
87+
"name": string,
88+
"fields": ReflectionParameter[]
89+
} | {
90+
"kind": "vector",
91+
"elementCount": number,
92+
"elementType": ReflectionType,
93+
} | {
94+
"kind": "scalar",
95+
"scalarType": `${"uint" | "int"}${8 | 16 | 32 | 64}` | `${"float"}${16 | 32 | 64}`,
96+
} | {
97+
"kind": "resource",
98+
"baseShape": "structuredBuffer",
99+
"access"?: "readWrite",
100+
"resultType": ReflectionType
101+
} | {
102+
"kind": "resource",
103+
"baseShape": "texture2D",
104+
"access"?: "readWrite"
105+
};
106+
107+
export type ReflectionParameter = {
108+
"binding": ReflectionBinding,
109+
"name": string,
110+
"type": ReflectionType,
111+
"userAttribs"?: {
112+
"arguments": any[],
113+
"name": string,
114+
}[],
115+
}
116+
76117
export type ReflectionJSON = {
77118
"entryPoints": {
78119
"name": string,
79120
"semanticName": string,
80121
"type": unknown
81122
}[],
82-
"parameters": {
83-
"binding": unknown,
84-
"name": string,
85-
"type": unknown,
86-
"userAttribs"?: {
87-
"arguments": any[],
88-
"name": string,
89-
}[],
90-
}[],
123+
"parameters": ReflectionParameter[],
91124
};
92125

93126

try-slang.ts

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ async function execFrame(timeMS: number) {
339339

340340
let resource = allocatedResources.get(command.resourceName);
341341
if (resource instanceof GPUBuffer) {
342-
size = [resource.size / 4, 1, 1];
342+
let elementSize = command.elementSize || 4;
343+
size = [resource.size / elementSize, 1, 1];
343344
}
344345
else if (resource instanceof GPUTexture) {
345346
size = [resource.width, resource.height, 1];
@@ -496,7 +497,11 @@ function checkShaderType(userSource: string) {
496497
}
497498

498499
export type ParsedCommand = {
499-
"type": "ZEROS" | "RAND",
500+
"type": "ZEROS",
501+
"count": number,
502+
"elementSize": number,
503+
} | {
504+
"type": "RAND",
500505
"count": number,
501506
} | {
502507
"type": "BLACK",
@@ -521,7 +526,7 @@ async function processResourceCommands(pipeline: ComputePipeline | GraphicsPipel
521526

522527
for (const { resourceName, parsedCommand } of resourceCommands) {
523528
if (parsedCommand.type === "ZEROS") {
524-
const elementSize = 4; // Assuming 4 bytes per element (e.g., float) TODO: infer from type.
529+
const elementSize = parsedCommand.elementSize;
525530
const bindingInfo = resourceBindings.get(resourceName);
526531
if (!bindingInfo) {
527532
throw new Error(`Resource ${resourceName} is not defined in the bindings.`);
@@ -539,12 +544,7 @@ async function processResourceCommands(pipeline: ComputePipeline | GraphicsPipel
539544
safeSet(allocatedResources, resourceName, buffer);
540545

541546
// Initialize the buffer with zeros.
542-
let zeros: BufferSource;
543-
if (elementSize == 4) {
544-
zeros = new Float32Array(parsedCommand.count);
545-
} else {
546-
throw new Error("Element size isn't handled");
547-
}
547+
let zeros: BufferSource = new Uint8Array(parsedCommand.count * elementSize);
548548
pipeline.device.queue.writeBuffer(buffer, 0, zeros);
549549
} else if (parsedCommand.type === "BLACK") {
550550
const size = parsedCommand.width * parsedCommand.height;
@@ -703,7 +703,7 @@ async function processResourceCommands(pipeline: ComputePipeline | GraphicsPipel
703703
}
704704
} else {
705705
// exhaustiveness check
706-
let x: never = parsedCommand.type;
706+
let x: never = parsedCommand;
707707
throw new Error("Invalid resource command type");
708708
}
709709
}
@@ -785,7 +785,7 @@ export let onRun = () => {
785785
resourceCommands = getCommandsFromAttributes(ret.reflection);
786786

787787
try {
788-
callCommands = parseCallCommands(userSource);
788+
callCommands = parseCallCommands(userSource, ret.reflection);
789789
}
790790
catch (error: any) {
791791
throw new Error("Error while parsing '//! CALL' commands: " + error.message);

util.ts

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { ReflectionJSON } from './compiler.js';
1+
import { ReflectionJSON, ReflectionType } from './compiler.js';
22
import { ParsedCommand } from './try-slang.js';
33

44
export function configContext(device: GPUDevice, canvas: HTMLCanvasElement) {
@@ -42,6 +42,42 @@ function reinterpretUint32AsFloat(uint32: number) {
4242
return float32View[0];
4343
}
4444

45+
function roundUpToNearest(x: number, nearest: number){
46+
return Math.ceil(x / nearest) * nearest;
47+
}
48+
49+
function getSize(reflectionType: ReflectionType): number {
50+
if(reflectionType.kind == "resource") {
51+
throw new Error("unimplemented");
52+
} else if(reflectionType.kind == "scalar") {
53+
const bitsMatch = reflectionType.scalarType.match(/\d+$/);
54+
if(bitsMatch == null) {
55+
throw new Error("Could not get bit count out of scalar type");
56+
}
57+
return parseInt(bitsMatch[0]) / 8;
58+
} else if(reflectionType.kind == "struct") {
59+
const alignment = reflectionType.fields.map((f) => {
60+
if(f.binding.kind == "uniform") return f.binding.size;
61+
else throw new Error("Invalid state")
62+
}).reduce((a, b) => Math.max(a, b));
63+
64+
const unalignedSize = reflectionType.fields.map((f) => {
65+
if(f.binding.kind == "uniform") return f.binding.offset + f.binding.size;
66+
else throw new Error("Invalid state")
67+
}).reduce((a, b) => Math.max(a, b));
68+
69+
return roundUpToNearest(unalignedSize, alignment);
70+
} else if(reflectionType.kind == "vector") {
71+
if(reflectionType.elementCount == 3) {
72+
return 4 * getSize(reflectionType.elementType);
73+
}
74+
return reflectionType.elementCount * getSize(reflectionType.elementType);
75+
} else {
76+
let x:never = reflectionType;
77+
throw new Error("Cannot get size of unrecognized reflection type");
78+
}
79+
}
80+
4581
/**
4682
* Here are some patterns we support:
4783
*
@@ -63,20 +99,41 @@ export function getCommandsFromAttributes(reflection: ReflectionJSON): { resourc
6399
if (!attribute.name.startsWith("playground_")) continue;
64100

65101
let playground_attribute_name = attribute.name.slice(11);
66-
if (playground_attribute_name == "ZEROS" || playground_attribute_name == "RAND") {
102+
if (playground_attribute_name == "ZEROS") {
103+
if(parameter.type.kind != "resource" || parameter.type.baseShape != "structuredBuffer") {
104+
throw new Error(`ZEROS attribute cannot be applied to ${parameter.name}, it only supports buffers`)
105+
}
67106
command = {
68107
type: playground_attribute_name,
69108
count: attribute.arguments[0] as number,
109+
elementSize: getSize(parameter.type.resultType),
110+
};
111+
} else if (playground_attribute_name == "RAND") {
112+
if(parameter.type.kind != "resource" || parameter.type.baseShape != "structuredBuffer") {
113+
throw new Error(`RAND attribute cannot be applied to ${parameter.name}, it only supports buffers`)
114+
}
115+
if(parameter.type.resultType.kind != "scalar" || parameter.type.resultType.scalarType != "float32") {
116+
throw new Error(`RAND attribute cannot be applied to ${parameter.name}, it only supports float buffers`)
117+
}
118+
command = {
119+
type: playground_attribute_name,
120+
count: attribute.arguments[0] as number
70121
};
71122
} else if (playground_attribute_name == "BLACK") {
123+
if(parameter.type.kind != "resource" || parameter.type.baseShape != "texture2D") {
124+
throw new Error(`BLACK attribute cannot be applied to ${parameter.name}, it only supports 2D textures`)
125+
}
72126
command = {
73-
type: "BLACK",
127+
type: playground_attribute_name,
74128
width: attribute.arguments[0] as number,
75129
height: attribute.arguments[1] as number,
76130
};
77131
} else if (playground_attribute_name == "URL") {
132+
if(parameter.type.kind != "resource" || parameter.type.baseShape != "texture2D") {
133+
throw new Error(`URL attribute cannot be applied to ${parameter.name}, it only supports 2D textures`)
134+
}
78135
command = {
79-
type: "URL",
136+
type: playground_attribute_name,
80137
url: attribute.arguments[0] as string,
81138
};
82139
}
@@ -97,13 +154,14 @@ export type CallCommand = {
97154
type: "RESOURCE_BASED",
98155
fnName: string,
99156
resourceName: string,
157+
elementSize?: number,
100158
} | {
101159
type: "FIXED_SIZE",
102160
fnName: string,
103161
size: number[],
104162
};
105163

106-
export function parseCallCommands(userSource: string): CallCommand[] {
164+
export function parseCallCommands(userSource: string, reflection: ReflectionJSON): CallCommand[] {
107165
// Look for commands of the form:
108166
//
109167
// 1. //! CALL(fn-name, SIZE_OF(<resource-name>)) ==> Dispatch a compute pass with the given
@@ -122,7 +180,16 @@ export function parseCallCommands(userSource: string): CallCommand[] {
122180
const args = match[2].split(',').map(arg => arg.trim());
123181

124182
if (args[0].startsWith("SIZE_OF")) {
125-
callCommands.push({ type: "RESOURCE_BASED", fnName, resourceName: args[0].slice(8, -1) });
183+
let resourceName = args[0].slice(8, -1);
184+
let resourceReflection = reflection.parameters.find((param) => param.name == resourceName);
185+
if(resourceReflection == undefined) {
186+
throw new Error(`Cannot find resource ${resourceName} for ${fnName} CALL command`)
187+
}
188+
let elementSize: number | undefined = undefined;
189+
if(resourceReflection.type.kind == "resource" && resourceReflection.type.baseShape == "structuredBuffer") {
190+
elementSize = getSize(resourceReflection.type.resultType);
191+
}
192+
callCommands.push({ type: "RESOURCE_BASED", fnName, resourceName, elementSize });
126193
}
127194
else {
128195
callCommands.push({ type: "FIXED_SIZE", fnName, size: args.map(Number) });

0 commit comments

Comments
 (0)