// Reconstructed from a whitespace-collapsed unified diff. Only the four
// functions the patch adds are reproduced as code; the remaining hunks are
// summarised here so that no information is lost:
//   * implicit3-rt/denoise.js, initDenoiserModel_temp(): the call
//     applyResidualDenoiser(unet) becomes applyNormalizedResidualDenoiser(unet).
//   * implicit3-rt/export_model.py: the checkpoint loaded changes from
//     data_spirulae_5/resunet2gan_1_3.pth to data_spirulae_5/resunet2gan_2_1.pth
//     and the export `name` changes from "runet2gan2" to "temp".
//   * scripts/dnn.js, Dnn.Conv2d311: a commented-out console.log is removed.


// ===== implicit3-rt/denoise.js =====

/**
 * Installs `renderer.denoiser`: a residual denoiser that standardises the
 * input image per channel, runs `model`, adds the predicted residual and then
 * undoes the standardisation before the final colour transform.
 *
 * @param {object} model - network exposing `layers.input`, `layers.output`
 *     and `forward()` (the only members used here).
 */
function applyNormalizedResidualDenoiser(model) {
    const gl = renderer.gl;

    // Shared by the forward and inverse transforms so that the inverse is exact.
    const NORMALIZATION_EPS = 1e-5;

    // Final pass: maps the network output back with (exp(c)-1)^2.2.
    const programOutput = createShaderProgram(gl, null,
        `#version 300 es
        precision highp float;

        uniform sampler2D uSrc;

        out vec4 fragColor;
        void main() {
            ivec2 coord = ivec2(gl_FragCoord.xy);
            vec3 c = texelFetch(uSrc, coord, 0).xyz;
            c = pow(max(exp(c)-1.0, 0.0), vec3(2.2));
            fragColor = vec4(c.xyz, 1.0);
        }`);

    let layerTempI = null;
    let layerTempO = null;

    // (Re)allocates the two scratch layers with dimensions rounded up to a
    // multiple of 16, releasing the previous ones.
    function updateLayer() {
        const w = Math.ceil(state.width / 16) * 16;
        const h = Math.ceil(state.height / 16) * 16;
        let oldLayer = layerTempI;
        layerTempI = new Dnn.CNNLayer(gl, 3, w, h);
        if (oldLayer) Dnn.destroyCnnLayer(gl, oldLayer);
        oldLayer = layerTempO;
        layerTempO = new Dnn.CNNLayer(gl, 3, w, h);
        if (oldLayer) Dnn.destroyCnnLayer(gl, oldLayer);
    }
    updateLayer();
    window.addEventListener("resize", function (event) {
        setTimeout(updateLayer, 20);
    });

    /**
     * @param {object} inputs - only `inputs.pixel === 'framebuffer'` is supported.
     * @param {WebGLFramebuffer|null} framebuffer - target of the final pass.
     * @throws {Error} for any other `inputs.pixel` value.
     */
    renderer.denoiser = function(inputs, framebuffer) {
        if (inputs.pixel !== 'framebuffer')
            throw new Error("Unsupported NN input");

        // load input: copy the currently bound framebuffer into the scratch layer
        gl.bindTexture(gl.TEXTURE_2D, layerTempI.imgs[0].texture);
        gl.copyTexImage2D(gl.TEXTURE_2D,
            0, gl.RGBA32F, 0, 0, state.width, state.height, 0);

        // normalize; the mean is handed to global_std so it is read back only once
        const mean = Dnn.global_mean(gl, layerTempI);
        const std = Dnn.global_std(gl, layerTempI, mean);
        Dnn.batch_norm_2d(gl, layerTempI, model.layers.input,
            0.0, 1.0, mean, std, NORMALIZATION_EPS);

        // inference residual model
        model.forward();
        Dnn.add(gl, model.layers.input, model.layers.output, layerTempI);

        // normalize back: y = sqrt(std^2 + eps) * x + mean, the exact inverse of
        // the forward pass. It is expressed through gamma/beta with a unit
        // std and zero mean so that std == 0 never produces Infinity or NaN.
        const scale = std.map((s) => Math.sqrt(s * s + NORMALIZATION_EPS));
        const zeros = new Array(mean.length).fill(0.0);
        const ones = new Array(mean.length).fill(1.0);
        Dnn.batch_norm_2d(gl, layerTempI, layerTempO, mean, scale, zeros, ones, 0.0);

        // gamma transform; the viewport is the one batch_norm_2d just set
        gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
        gl.useProgram(programOutput);
        setPositionBuffer(gl, programOutput);
        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, layerTempO.imgs[0].texture);
        gl.uniform1i(gl.getUniformLocation(programOutput, "uSrc"), 0);
        gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
    };
}


// ===== scripts/dnn.js =====

/**
 * Shared implementation of Dnn.global_mean / Dnn.global_std: averages
 * `texelExpr` (a GLSL expression in the fetched texel `c`) over every texel
 * of each image in `buffer`.
 *
 * The GPU pass writes a 64x64 grid of partial averages, one per block of
 * source texels; the grid is then read back and summed on the CPU.
 *
 * @param {WebGL2RenderingContext} gl
 * @param {object} buffer - layer with `n` channels packed four per entry of `imgs`.
 * @param {string} programKey - property of Dnn that caches the shader program.
 * @param {string} targetKey - property of Dnn that caches the render target.
 * @param {string} texelExpr - e.g. "c" for the mean, "c*c" for the mean square.
 * @returns {number[]} one average per channel (length `buffer.n`).
 */
Dnn.reduceChannelAverage = function(gl, buffer, programKey, targetKey, texelExpr) {
    const tile = 64;
    if (!Dnn[programKey]) {
        // highp: sums of (squared) float32 texels can exceed the mediump range.
        Dnn[programKey] = createShaderProgram(gl, null,
            `#version 300 es
            precision highp float;

            uniform sampler2D uSrc;
            out vec4 fragColor;

            void main() {
                ivec2 ires = textureSize(uSrc, 0);
                ivec2 span = (ires+${tile}-1) / ${tile};
                vec2 sc = vec2(ires) / float(${tile});
                ivec2 xy = ivec2(gl_FragCoord.xy);
                ivec2 pos0 = xy * span;
                ivec2 pos1 = min(pos0+span, ires);
                vec4 total = vec4(0);
                for (int x = pos0.x; x < pos1.x; x++) {
                    vec4 s = vec4(0);
                    for (int y = pos0.y; y < pos1.y; y++) {
                        vec4 c = texelFetch(uSrc, ivec2(x,y), 0);
                        s += ${texelExpr};
                    }
                    total += s / sc.y;
                }
                fragColor = total / sc.x;
            }`);
    }
    if (!Dnn[targetKey]) {
        Dnn[targetKey] = createRenderTarget(gl, tile, tile, false, true, false);
    }
    const program = Dnn[programKey];
    const target = Dnn[targetKey];

    gl.useProgram(program);
    // The render target is tile x tile; a viewport sized to the source buffer
    // would leave stale texels behind whenever the buffer is smaller than that.
    gl.viewport(0, 0, tile, tile);
    gl.disable(gl.BLEND);

    const pixels = new Float32Array(4 * tile * tile);
    const averages = new Array(buffer.n).fill(0.0);
    for (let i = 0; i < buffer.n; i += 4) {
        gl.bindFramebuffer(gl.FRAMEBUFFER, target.framebuffer);
        setPositionBuffer(gl, program);
        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, buffer.imgs[i / 4].texture);
        gl.uniform1i(gl.getUniformLocation(program, "uSrc"), 0);
        gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
        gl.readPixels(0, 0, tile, tile, gl.RGBA, gl.FLOAT, pixels);

        const total = [0.0, 0.0, 0.0, 0.0];
        for (let k = 0; k < pixels.length; k++)
            total[k % 4] += pixels[k];
        for (let c = i; c < i + 4 && c < buffer.n; c++)
            averages[c] = total[c - i] / (tile * tile);
    }
    return averages;
};


/**
 * Per-channel mean over all texels of `buffer`.
 *
 * @param {WebGL2RenderingContext} gl
 * @param {object} buffer - layer to reduce.
 * @returns {number[]} `buffer.n` means.
 */
Dnn.global_mean = function(gl, buffer) {
    return Dnn.reduceChannelAverage(
        gl, buffer, 'programGlobalMean', 'bufferGlobalMean', 'c');
};


/**
 * Per-channel standard deviation, sqrt(max(E[x^2] - E[x]^2, 0)).
 *
 * @param {WebGL2RenderingContext} gl
 * @param {object} buffer - layer to reduce.
 * @param {?number[]} [mean=null] - precomputed Dnn.global_mean(gl, buffer);
 *     computed here when omitted.
 * @returns {number[]} `buffer.n` standard deviations.
 */
Dnn.global_std = function(gl, buffer, mean=null) {
    const meanSquare = Dnn.reduceChannelAverage(
        gl, buffer, 'programGlobalSquaredMean', 'bufferGlobalSquaredMean', 'c*c');
    const channelMean = mean == null ? Dnn.global_mean(gl, buffer) : mean;
    // Clamp at zero: rounding can make the difference slightly negative.
    return meanSquare.map((m2, i) =>
        Math.sqrt(Math.max(m2 - channelMean[i] * channelMean[i], 0.0)));
};


/**
 * Writes gamma * (x - mean) / sqrt(std^2 + eps) + beta into `buffer_out`
 * for every channel of `buffer_in`.
 *
 * @param {WebGL2RenderingContext} gl
 * @param {object} buffer_in - source layer.
 * @param {object} buffer_out - destination layer; same `n`, `w` and `h`.
 * @param {number|number[]} beta - offset, scalar or per channel.
 * @param {number|number[]} gamma - scale, scalar or per channel.
 * @param {?number[]} [mean=null] - per-channel mean; measured from `buffer_in` when null.
 * @param {?number[]} [std=null] - per-channel std; measured from `buffer_in` when null.
 * @param {number} [eps=1e-5] - added to the variance before the square root.
 * @throws {Error} when the two layers differ in channel count or dimensions.
 */
Dnn.batch_norm_2d = function(
    gl, buffer_in, buffer_out,
    beta, gamma, mean=null, std=null, eps=1e-5
) {
    if (buffer_in.n != buffer_out.n)
        throw new Error("Input and output buffer sizes don't match.");
    if (buffer_out.w != buffer_in.w || buffer_out.h != buffer_in.h)
        throw new Error("Input and output buffer dimensions don't match.");
    if (!Dnn.programBatchNorm2d) {
        Dnn.programBatchNorm2d = createShaderProgram(gl, null,
            `#version 300 es
            precision highp float;

            uniform sampler2D uSrc;
            out vec4 fragColor;

            uniform vec4 slope;
            uniform vec4 intercept;

            void main() {
                vec4 c = texelFetch(uSrc, ivec2(gl_FragCoord.xy), 0);
                fragColor = slope * c + intercept;
            }`);
    }

    // Statistics default to those of the input layer.
    const channelMean = mean == null ? Dnn.global_mean(gl, buffer_in) : mean;
    const channelStd = std == null
        ? Dnn.global_std(gl, buffer_in, channelMean) : std;

    const n = buffer_in.n;
    const betas = typeof beta === 'number' ? new Array(n).fill(beta) : beta;
    const gammas = typeof gamma === 'number' ? new Array(n).fill(gamma) : gamma;

    // Channels are packed four per texture; unused lanes get slope = intercept = 0.
    const padded = Math.ceil(n / 4) * 4;
    const slope = new Array(padded).fill(0.0);
    const intercept = new Array(padded).fill(0.0);
    for (let i = 0; i < n; i++) {
        const invStd = 1.0 / Math.sqrt(channelStd[i] * channelStd[i] + eps);
        slope[i] = invStd * gammas[i];
        // gamma scales the mean term as well: y = slope * (x - mean) + beta.
        intercept[i] = betas[i] - slope[i] * channelMean[i];
    }

    const program = Dnn.programBatchNorm2d;
    gl.useProgram(program);
    gl.viewport(0, 0, buffer_in.w, buffer_in.h);
    gl.disable(gl.BLEND);
    for (let i = 0; i < padded; i += 4) {
        gl.bindFramebuffer(gl.FRAMEBUFFER, buffer_out.imgs[i / 4].framebuffer);
        setPositionBuffer(gl, program);
        gl.activeTexture(gl.TEXTURE0);
        gl.bindTexture(gl.TEXTURE_2D, buffer_in.imgs[i / 4].texture);
        gl.uniform1i(gl.getUniformLocation(program, "uSrc"), 0);
        gl.uniform4f(gl.getUniformLocation(program, "slope"),
            slope[i], slope[i + 1], slope[i + 2], slope[i + 3]);
        gl.uniform4f(gl.getUniformLocation(program, "intercept"),
            intercept[i], intercept[i + 1], intercept[i + 2], intercept[i + 3]);
        gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
    }
};