|
| 1 | +import { createSingletonBuffer, WebGPUBufferSet } from "./buffertools"; |
| 2 | +import { StatefulGPU } from "./lib"; |
| 3 | + |
| 4 | +type TinyForestParams = { |
| 5 | + nTrees: number; |
| 6 | + depth: number; |
| 7 | + // The number of features to consider at each split. |
| 8 | + maxFeatures: number; |
| 9 | + D: number; |
| 10 | +} |
| 11 | + |
| 12 | +const defaultTinyForestParams : TinyForestParams = { |
| 13 | + nTrees: 128, |
| 14 | + depth: 8, |
| 15 | + maxFeatures: 32, |
| 16 | + D: 768, |
| 17 | +} |
| 18 | + |
| 19 | +export class TinyForest extends StatefulGPU { |
| 20 | + params: TinyForestParams; |
| 21 | + |
| 22 | + private _bootstrapSamples?: GPUBuffer; // On the order of 100 KB |
| 23 | + protected _forests?: GPUBuffer // On the order of 10 MB. |
| 24 | + // private trainedThrough: number = 0; |
| 25 | + constructor( |
| 26 | + device: GPUDevice, |
| 27 | + bufferSize = 1024 * 1024 * 256, |
| 28 | + t: Partial<TinyForestParams> = {}) { |
| 29 | + super(device, bufferSize) |
| 30 | + this.params = {...defaultTinyForestParams, ...t} |
| 31 | + this.initializeForestsToZero() |
| 32 | + this.bufferSet = new WebGPUBufferSet(device, bufferSize); |
| 33 | + } |
| 34 | + |
| 35 | + countPipeline(): GPUComputePipeline { |
| 36 | + const { device } = this; |
| 37 | + // const { maxFeatures, nTrees } = this.params |
| 38 | + // const OPTIONS = 2; |
| 39 | + // const countBuffer = device.createBuffer({ |
| 40 | + // size: OPTIONS * maxFeatures * nTrees * 4, |
| 41 | + // usage: GPUBufferUsage.STORAGE & GPUBufferUsage.COPY_SRC, |
| 42 | + // mappedAtCreation: false |
| 43 | + // }); |
| 44 | + |
| 45 | + const layout = device.createBindGroupLayout({ |
| 46 | + entries: [ |
| 47 | + { |
| 48 | + // features buffer; |
| 49 | + binding: 0, |
| 50 | + visibility: GPUShaderStage.COMPUTE, |
| 51 | + buffer: { type: 'storage' } |
| 52 | + }, |
| 53 | + { |
| 54 | + // dims to check array; |
| 55 | + binding: 1, |
| 56 | + visibility: GPUShaderStage.COMPUTE, |
| 57 | + buffer: { type: 'storage' } |
| 58 | + }, |
| 59 | + { |
| 60 | + // output count buffer. |
| 61 | + binding: 2, |
| 62 | + visibility: GPUShaderStage.COMPUTE, |
| 63 | + buffer: { type: 'storage' } |
| 64 | + } |
| 65 | + ] |
| 66 | + }) |
| 67 | + |
| 68 | + // const subsetsToCheck = this.chooseNextFeatures(); |
| 69 | + const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [layout] }); |
| 70 | + |
| 71 | + const shaderModule = device.createShaderModule({ code: ` |
| 72 | + @group(0) @binding(0) var<storage, read> features: array<u32>; |
| 73 | + @group(0) @binding(1) var<storage, read> dimsToCheck: array<u16>; |
| 74 | + @group(0) @binding(2) var<storage, write> counts: array<u32>; |
| 75 | +
|
| 76 | + @compute @workgroup_size(64) |
| 77 | + //TODOD HERE |
| 78 | + ` }); |
| 79 | + |
| 80 | + |
| 81 | + return device.createComputePipeline({ |
| 82 | + layout: pipelineLayout, |
| 83 | + compute: { |
| 84 | + module: shaderModule, |
| 85 | + entryPoint: 'main' |
| 86 | + } |
| 87 | + }); |
| 88 | + } |
| 89 | + |
| 90 | + //@ts-expect-error foo |
| 91 | + private chooseNextFeatures(n = 32) { |
| 92 | + console.log({n}) |
| 93 | + const { maxFeatures, nTrees, D } = this.params; |
| 94 | + const features = new Uint16Array(maxFeatures * D); |
| 95 | + for (let i = 0; i < nTrees; i++) { |
| 96 | + const set = new Set<number>(); |
| 97 | + while (set.size < maxFeatures) { |
| 98 | + set.add(Math.floor(Math.random() * D)); |
| 99 | + } |
| 100 | + const arr = new Uint16Array([...set].sort()); |
| 101 | + features.set(arr, i * maxFeatures); |
| 102 | + } |
| 103 | + return createSingletonBuffer( |
| 104 | + this.device, |
| 105 | + features, |
| 106 | + GPUBufferUsage.STORAGE |
| 107 | + ) |
| 108 | + } |
| 109 | + |
| 110 | + |
| 111 | + |
| 112 | + initializeForestsToZero() { |
| 113 | + // Each tree is a set of bits; For every possible configuration |
| 114 | + // the first D indicating |
| 115 | + // the desired outcome for the dimension, |
| 116 | + // the second D indicating whether the bits in those |
| 117 | + // positions are to be considered in checking if the tree |
| 118 | + // fits. There are 2**depth bitmasks for each dimension--each point |
| 119 | + // will match only one, and part of the inference task is determining which one. |
| 120 | + |
| 121 | + const treeSizeInBytes = |
| 122 | + 2 * this.params.D * (2 ** this.params.depth) / 8; |
| 123 | + |
| 124 | + const data = new Uint8Array(treeSizeInBytes * this.params.nTrees) |
| 125 | + this._forests = createSingletonBuffer( |
| 126 | + this.device, |
| 127 | + data, |
| 128 | + GPUBufferUsage.STORAGE |
| 129 | + ) |
| 130 | + } |
| 131 | + |
| 132 | + |
| 133 | + // Rather than actually bootstrap, we generate a single |
| 134 | + // list of 100,000 numbers drawn from a poisson distribution. |
| 135 | + // These serve as weights for draws with replacement; to |
| 136 | + // bootstrap any given record batch, we take a sequence of |
| 137 | + // numbers from the buffer with offset i. |
| 138 | + get bootstrapSamples() { |
| 139 | + if (this._bootstrapSamples) { |
| 140 | + return this._bootstrapSamples |
| 141 | + } else { |
| 142 | + const arr = new Uint8Array(100000) |
| 143 | + for (let i = 0; i < arr.length; i++) { |
| 144 | + arr[i] = poissonRandomNumber() |
| 145 | + } |
| 146 | + this._bootstrapSamples = createSingletonBuffer( |
| 147 | + this.device, |
| 148 | + arr, |
| 149 | + GPUBufferUsage.STORAGE |
| 150 | + ) |
| 151 | + return this._bootstrapSamples |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + |
| 156 | +} |
| 157 | + |
| 158 | + |
| 159 | +function poissonRandomNumber() : number { |
| 160 | + let p = 1.0; |
| 161 | + let k = 0; |
| 162 | + |
| 163 | + do { |
| 164 | + k++; |
| 165 | + p *= Math.random(); |
| 166 | + } while (p > 1/Math.E); |
| 167 | + |
| 168 | + return k - 1; |
| 169 | +} |
| 170 | + |
0 commit comments