Skip to content

Commit

Permalink
batch norm
Browse files Browse the repository at this point in the history
  • Loading branch information
harry7557558 committed Mar 14, 2024
1 parent 4ba6cd5 commit d278282
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 4 deletions.
65 changes: 64 additions & 1 deletion implicit3-rt/denoise.js
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,69 @@ function applyResidualDenoiser(model) {
}
}

/**
 * Install `renderer.denoiser` as a residual denoiser that standardizes its
 * input per channel before inference and undoes the standardization after.
 *
 * Pipeline executed on every call of the installed denoiser:
 *   1. copy the currently bound framebuffer into a temporary layer,
 *   2. measure the per-channel global mean / standard deviation,
 *   3. map the layer to (x - mean) / sqrt(std^2 + eps) into `model.layers.input`,
 *   4. run `model.forward()` and combine input and output with `Dnn.add`,
 *   5. map the result back with y * sqrt(std^2 + eps) + mean,
 *   6. draw it to `framebuffer` through the output shader below.
 *
 * @param {object} model - Network exposing `layers.input`, `layers.output`
 *     and `forward()`.
 */
function applyNormalizedResidualDenoiser(model) {
    let gl = renderer.gl;

    // Epsilon added to the variance. The forward and the inverse transform
    // share it so that step 5 is the exact inverse of step 3.
    const NORM_EPS = 1e-5;

    // Output shader: maps each texel through max(exp(c)-1, 0)^2.2.
    // NOTE(review): presumably this undoes a log/gamma encoding applied by the
    // renderer before denoising - confirm against the code that fills `inputs`.
    let programOutput = createShaderProgram(gl, null,
`#version 300 es
precision highp float;
uniform sampler2D uSrc;
out vec4 fragColor;
void main() {
    ivec2 coord = ivec2(gl_FragCoord.xy);
    vec3 c = texelFetch(uSrc, coord, 0).xyz;
    c = pow(max(exp(c)-1.0, 0.0), vec3(2.2));
    fragColor = vec4(c.xyz, 1.0);
}`);

    var layerTempI = null;
    var layerTempO = null;

    // (Re)allocate both 3-channel scratch layers with the canvas size rounded
    // up to a multiple of 16; the new layer is created before the old one is
    // destroyed.
    function updateLayer() {
        var w = Math.ceil(state.width/16)*16;
        var h = Math.ceil(state.height/16)*16;
        var oldLayer = layerTempI;
        layerTempI = new Dnn.CNNLayer(gl, 3, w, h);
        if (oldLayer) Dnn.destroyCnnLayer(gl, oldLayer);
        oldLayer = layerTempO;
        layerTempO = new Dnn.CNNLayer(gl, 3, w, h);
        if (oldLayer) Dnn.destroyCnnLayer(gl, oldLayer);
    }
    updateLayer();
    window.addEventListener("resize", function (event) {
        // short delay before reallocating after a resize event
        setTimeout(updateLayer, 20);
    });

    renderer.denoiser = function(inputs, framebuffer) {
        if (inputs.pixel !== 'framebuffer')
            throw new Error("Unsupported NN input");

        // load input: copy from the framebuffer that is bound when this is called
        gl.bindTexture(gl.TEXTURE_2D, layerTempI.imgs[0].texture);
        gl.copyTexImage2D(gl.TEXTURE_2D,
            0, gl.RGBA32F, 0, 0, state.width, state.height, 0);

        // normalize: x -> (x - mean) / sqrt(std^2 + eps)
        var mean = Dnn.global_mean(gl, layerTempI);
        var std = Dnn.global_std(gl, layerTempI);
        Dnn.batch_norm_2d(gl, layerTempI, model.layers.input,
            0.0, 1.0, mean, std, NORM_EPS);

        // inference residual model
        // NOTE(review): Dnn.add presumably writes input + output into
        // layerTempI, which is what gets read below - confirm.
        model.forward();
        Dnn.add(gl, model.layers.input, model.layers.output, layerTempI);

        // normalize back: y -> y * sqrt(std^2 + eps) + mean.
        // The scale and offset are passed directly as gamma / beta with a
        // zero mean, unit std and eps = 0, instead of reciprocals of std:
        // 1 / std is Infinity for a constant channel (std == 0) and would
        // turn the whole output into NaN, and it is not the exact inverse of
        // the forward step once eps is added.
        var scale = std.map(function (s) {
            return Math.sqrt(s*s + NORM_EPS);
        });
        var zeros = new Array(mean.length).fill(0.0);
        var ones = new Array(mean.length).fill(1.0);
        Dnn.batch_norm_2d(gl, layerTempI, layerTempO,
            mean, scale, zeros, ones, 0.0);

        // gamma transform
        gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
        gl.useProgram(programOutput);
        setPositionBuffer(gl, programOutput);
        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, layerTempO.imgs[0].texture);
        gl.uniform1i(gl.getUniformLocation(programOutput, "uSrc"), 0);
        gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
    };
}


function initDenoiserModel_runet1(params) {
let unet = new UNet1(3, 12, 16, 24, 32, params);
Expand Down Expand Up @@ -482,7 +545,7 @@ function initDenoiserModel_temp(params) {
window.addEventListener("resize", function (event) {
setTimeout(unet.updateLayers, 20);
});
applyResidualDenoiser(unet);
applyNormalizedResidualDenoiser(unet);
useDenoiser.denoisers['temp'] = renderer.denoiser;
}

Expand Down
4 changes: 2 additions & 2 deletions implicit3-rt/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class Model(torch.nn.Module):
pass

model = torch.load(
'../../Graphics/image/denoise/data_spirulae_5/resunet2gan_1_3.pth',
'../../Graphics/image/denoise/data_spirulae_5/resunet2gan_2_1.pth',
map_location=torch.device('cpu'))

state_dict = model.state_dict()
Expand Down Expand Up @@ -36,7 +36,7 @@ class Model(torch.nn.Module):
data = np.concatenate((data, data_))
print(key, tensor.shape, (amin, amax), sep='\t')

name = "runet2gan2"
name = "temp"
with open(f"denoise_models/denoise_{name}.json", 'w') as fp:
json.dump(info, fp)
data.tofile(f"denoise_models/denoise_{name}.bin")
179 changes: 178 additions & 1 deletion scripts/dnn.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,184 @@ Dnn.destroyCnnLayer = function(gl, layer) {
}


/**
 * Compute the mean of every channel of a layer over all of its texels.
 *
 * The reduction runs in two stages: a fragment shader sums the source texture
 * into a `tile` x `tile` grid of partial means (each output texel covers a
 * rectangle of source texels), then the grid is read back and averaged on the
 * CPU. The shader program and the render target are created on first use and
 * cached on `Dnn`.
 *
 * @param {WebGL2RenderingContext} gl
 * @param {object} buffer - Layer with `n` channels, size `w` x `h`, and
 *     `imgs[k].texture` holding channels 4k..4k+3.
 * @returns {number[]} Array of length `buffer.n` with the per-channel mean.
 */
Dnn.global_mean = function(gl, buffer) {
    const tile = 64;
    if (!Dnn.programGlobalMean) {
        // highp: the shader accumulates many texels per fragment, which
        // loses precision where mediump is implemented as half floats.
        Dnn.programGlobalMean = createShaderProgram(gl, null,
`#version 300 es
precision highp float;
uniform sampler2D uSrc;
out vec4 fragColor;
void main() {
    ivec2 ires = textureSize(uSrc, 0);
    ivec2 tile = (ires+${tile}-1) / ${tile};
    vec2 sc = vec2(ires) / float(${tile});
    ivec2 xy = ivec2(gl_FragCoord.xy);
    ivec2 pos0 = xy * tile;
    ivec2 pos1 = min(pos0+tile, ires);
    vec4 total = vec4(0);
    for (int x = pos0.x; x < pos1.x; x++) {
        vec4 s = vec4(0);
        for (int y = pos0.y; y < pos1.y; y++)
            s += texelFetch(uSrc, ivec2(x,y), 0);
        total += s / sc.y;
    }
    fragColor = total / sc.x;
}`);
    }
    if (!Dnn.bufferGlobalMean) {
        // NOTE(review): assumes these flags request a float color target,
        // since the read-back below uses gl.FLOAT - confirm.
        Dnn.bufferGlobalMean = createRenderTarget(gl, tile, tile, false, true, false);
    }
    let program = Dnn.programGlobalMean;
    gl.useProgram(program);
    // The draw must cover the whole tile x tile target. Using the layer size
    // here would leave stale texels in the target whenever the layer is
    // smaller than the target, and those texels are summed below.
    gl.viewport(0, 0, tile, tile);
    gl.disable(gl.BLEND);
    const uSrcLocation = gl.getUniformLocation(program, "uSrc");
    const pixels = new Float32Array(4*tile*tile);
    var mean = new Array(buffer.n).fill(0.0);
    for (var i = 0; i < buffer.n; i += 4) {
        gl.bindFramebuffer(gl.FRAMEBUFFER, Dnn.bufferGlobalMean.framebuffer);
        setPositionBuffer(gl, program);
        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, buffer.imgs[i/4].texture);
        gl.uniform1i(uSrcLocation, 0);
        gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
        gl.readPixels(0, 0, tile, tile, gl.RGBA, gl.FLOAT, pixels);
        // sum the partial means per RGBA component
        var total = [0.0, 0.0, 0.0, 0.0];
        for (var k = 0; k < pixels.length; k++)
            total[k%4] += pixels[k];
        // the last texture may carry fewer than 4 valid channels
        for (var c = i; c < i+4 && c < buffer.n; c++)
            mean[c] = total[c-i] / (tile*tile);
    }
    // leave the viewport at the layer size, as callers saw it before
    gl.viewport(0, 0, buffer.w, buffer.h);
    return mean;
};


/**
 * Compute the standard deviation of every channel of a layer over all of its
 * texels, as sqrt(max(E[x^2] - E[x]^2, 0)).
 *
 * E[x^2] is reduced on the GPU into a `tile` x `tile` grid of partial means
 * that is read back and averaged on the CPU. The shader program and the
 * render target are created on first use and cached on `Dnn`.
 *
 * @param {WebGL2RenderingContext} gl
 * @param {object} buffer - Layer with `n` channels, size `w` x `h`, and
 *     `imgs[k].texture` holding channels 4k..4k+3.
 * @param {?number[]} [mean=null] - Per-channel mean of `buffer` if the caller
 *     already has it; when omitted it is computed with `Dnn.global_mean`.
 * @returns {number[]} Array of length `buffer.n` with the per-channel
 *     standard deviation.
 */
Dnn.global_std = function(gl, buffer, mean=null) {
    const tile = 64;
    if (!Dnn.programGlobalSquaredMean) {
        // highp: the shader accumulates many squared texels per fragment,
        // which loses precision where mediump is implemented as half floats.
        Dnn.programGlobalSquaredMean = createShaderProgram(gl, null,
`#version 300 es
precision highp float;
uniform sampler2D uSrc;
out vec4 fragColor;
void main() {
    ivec2 ires = textureSize(uSrc, 0);
    ivec2 tile = (ires+${tile}-1) / ${tile};
    vec2 sc = vec2(ires) / float(${tile});
    ivec2 xy = ivec2(gl_FragCoord.xy);
    ivec2 pos0 = xy * tile;
    ivec2 pos1 = min(pos0+tile, ires);
    vec4 total = vec4(0);
    for (int x = pos0.x; x < pos1.x; x++) {
        vec4 s = vec4(0);
        for (int y = pos0.y; y < pos1.y; y++) {
            vec4 c = texelFetch(uSrc, ivec2(x,y), 0);
            s += c*c;
        }
        total += s / sc.y;
    }
    fragColor = total / sc.x;
}`);
    }
    if (!Dnn.bufferGlobalSquaredMean) {
        // NOTE(review): assumes these flags request a float color target,
        // since the read-back below uses gl.FLOAT - confirm.
        Dnn.bufferGlobalSquaredMean = createRenderTarget(gl, tile, tile, false, true, false);
    }
    let program = Dnn.programGlobalSquaredMean;
    gl.useProgram(program);
    // The draw must cover the whole tile x tile target. Using the layer size
    // here would leave stale texels in the target whenever the layer is
    // smaller than the target, and those texels are summed below.
    gl.viewport(0, 0, tile, tile);
    gl.disable(gl.BLEND);
    const uSrcLocation = gl.getUniformLocation(program, "uSrc");
    const pixels = new Float32Array(4*tile*tile);
    var mean2 = new Array(buffer.n).fill(0.0);
    for (var i = 0; i < buffer.n; i += 4) {
        gl.bindFramebuffer(gl.FRAMEBUFFER, Dnn.bufferGlobalSquaredMean.framebuffer);
        setPositionBuffer(gl, program);
        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, buffer.imgs[i/4].texture);
        gl.uniform1i(uSrcLocation, 0);
        gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
        gl.readPixels(0, 0, tile, tile, gl.RGBA, gl.FLOAT, pixels);
        // sum the partial means per RGBA component
        var total = [0.0, 0.0, 0.0, 0.0];
        for (var k = 0; k < pixels.length; k++)
            total[k%4] += pixels[k];
        // the last texture may carry fewer than 4 valid channels
        for (var c = i; c < i+4 && c < buffer.n; c++)
            mean2[c] = total[c-i] / (tile*tile);
    }
    // leave the viewport at the layer size, as callers saw it before
    gl.viewport(0, 0, buffer.w, buffer.h);
    if (mean == null)
        mean = Dnn.global_mean(gl, buffer);
    var std = new Array(buffer.n).fill(0.0);
    // clamp at zero: rounding can make E[x^2] - E[x]^2 slightly negative
    for (var j = 0; j < buffer.n; j++)
        std[j] = Math.sqrt(Math.max(mean2[j] - mean[j]*mean[j], 0.0));
    return std;
};


/**
 * Apply a per-channel affine normalization to a layer:
 *
 *     out = gamma * (in - mean) / sqrt(std^2 + eps) + beta
 *
 * evaluated on the GPU as `slope * in + intercept`, four channels per draw.
 * The shader program is created on first use and cached on `Dnn`.
 *
 * @param {WebGL2RenderingContext} gl
 * @param {object} buffer_in - Source layer (`n`, `w`, `h`, `imgs[k].texture`).
 * @param {object} buffer_out - Destination layer with the same `n`, `w` and
 *     `h` (`imgs[k].framebuffer` is rendered into).
 * @param {number|number[]} beta - Offset; a number is used for every channel.
 * @param {number|number[]} gamma - Scale; a number is used for every channel.
 * @param {?number[]} [mean=null] - Per-channel mean; measured from
 *     `buffer_in` with `Dnn.global_mean` when null.
 * @param {?number[]} [std=null] - Per-channel standard deviation; measured
 *     from `buffer_in` with `Dnn.global_std` when null.
 * @param {number} [eps=1e-5] - Added to the variance before the square root.
 * @throws {Error} If the two layers differ in channel count or dimensions.
 */
Dnn.batch_norm_2d = function(
    gl, buffer_in, buffer_out,
    beta, gamma, mean=null, std=null, eps=1e-5
) {
    if (buffer_in.n !== buffer_out.n)
        throw new Error("Input and output buffer sizes don't match.");
    if (buffer_out.w !== buffer_in.w || buffer_out.h !== buffer_in.h)
        throw new Error("Input and output buffer dimensions don't match.");
    if (!Dnn.programBatchNorm2d) {
        Dnn.programBatchNorm2d = createShaderProgram(gl, null,
`#version 300 es
precision mediump float;
uniform sampler2D uSrc;
out vec4 fragColor;
uniform vec4 slope;
uniform vec4 intercept;
void main() {
    vec4 c = texelFetch(uSrc, ivec2(gl_FragCoord.xy), 0);
    fragColor = slope * c + intercept;
}`);
    }
    // Missing statistics are measured from the input layer itself.
    if (mean === null)
        mean = Dnn.global_mean(gl, buffer_in);
    if (std === null)
        std = Dnn.global_std(gl, buffer_in);
    const n = buffer_in.n;
    if (typeof beta === 'number')
        beta = new Array(n).fill(beta);
    if (typeof gamma === 'number')
        gamma = new Array(n).fill(gamma);
    // Coefficients are padded with zeros up to a multiple of 4 because every
    // draw handles one RGBA texture, i.e. four channels at once.
    const padded = 4 * Math.ceil(n / 4);
    const slope = new Array(padded).fill(0.0);
    const intercept = new Array(padded).fill(0.0);
    for (var c = 0; c < n; c++) {
        const invStd = 1.0 / Math.sqrt(std[c]*std[c]+eps);
        slope[c] = invStd * gamma[c];
        // gamma scales the mean term as well: gamma*(x-mean)*invStd + beta
        intercept[c] = beta[c] - slope[c] * mean[c];
    }
    let program = Dnn.programBatchNorm2d;
    gl.useProgram(program);
    gl.viewport(0, 0, buffer_in.w, buffer_in.h);
    gl.disable(gl.BLEND);
    const uSrcLocation = gl.getUniformLocation(program, "uSrc");
    const slopeLocation = gl.getUniformLocation(program, "slope");
    const interceptLocation = gl.getUniformLocation(program, "intercept");
    for (var i = 0; i < padded; i += 4) {
        gl.bindFramebuffer(gl.FRAMEBUFFER, buffer_out.imgs[i/4].framebuffer);
        setPositionBuffer(gl, program);
        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, buffer_in.imgs[i/4].texture);
        gl.uniform1i(uSrcLocation, 0);
        gl.uniform4f(slopeLocation,
            slope[i], slope[i+1], slope[i+2], slope[i+3]);
        gl.uniform4f(interceptLocation,
            intercept[i], intercept[i+1], intercept[i+2], intercept[i+3]);
        gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
    }
};


Dnn.Conv2d311 = function(
n_in, n_out, weights, biases = []
) {
Expand Down Expand Up @@ -132,7 +310,6 @@ Dnn.Conv2d311 = function(
this.weightTextureData);
}

// console.log(buffer_in.w*buffer_in.h);
let useWeightTexture = (buffer_in.w*buffer_in.h < 1e+4);
let program = useWeightTexture ?
Dnn.programConv2d311wt : Dnn.programConv2d311;
Expand Down

0 comments on commit d278282

Please sign in to comment.