Skip to content

Commit

Permalink
wip attention model
Browse files Browse the repository at this point in the history
  • Loading branch information
harry7557558 committed Mar 15, 2024
1 parent d278282 commit 94de717
Show file tree
Hide file tree
Showing 5 changed files with 491 additions and 107 deletions.
169 changes: 154 additions & 15 deletions implicit3-rt/denoise.js
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,17 @@ function UNet1(nin, n0, n1, n2, n3, params) {


/**
 * Builds the 'unet1' denoiser from exported weights and registers it.
 *
 * Constructs a UNet1 with fixed channel sizes (3, 16, 24, 48, 64), installs it
 * as `renderer.denoiser` through applyDenoiser(), and stores that function
 * under `useDenoiser.denoisers['unet1']`.
 *
 * @param {Object} params - weight/bias table forwarded unchanged to UNet1.
 */
function initDenoiserModel_unet1(params) {
    let unet = new UNet1(3, 16, 24, 48, 64, params);
    // Re-run the model's layer setup shortly after a window resize.
    // NOTE(review): the 20 ms delay presumably lets the renderer update its
    // own size first - confirm.
    window.addEventListener("resize", function (event) {
        setTimeout(unet.updateLayers, 20);
    });
    // applyDenoiser() assigns renderer.denoiser; keep a named reference to it.
    applyDenoiser(unet);
    useDenoiser.denoisers['unet1'] = renderer.denoiser;
}


function applyDenoiser(model) {
let gl = renderer.gl;

let programOutput = createShaderProgram(gl, null,
`#version 300 es
Expand All @@ -223,23 +228,21 @@ function initDenoiserModel_unet1(params) {
renderer.denoiser = function(inputs, framebuffer) {
if (inputs.pixel !== 'framebuffer')
throw new Error("Unsupported NN input");
gl.bindTexture(gl.TEXTURE_2D, unet.layers.input.imgs[0].texture);
gl.bindTexture(gl.TEXTURE_2D, model.layers.input.imgs[0].texture);
gl.copyTexImage2D(gl.TEXTURE_2D,
0, gl.RGBA32F, 0, 0, state.width, state.height, 0);
unet.forward();
model.forward();
gl.disable(gl.BLEND);
gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
gl.useProgram(programOutput);
setPositionBuffer(gl, programOutput);
gl.activeTexture(gl.TEXTURE0);
gl.bindTexture(gl.TEXTURE_2D, unet.layers.output.imgs[0].texture);
gl.bindTexture(gl.TEXTURE_2D, model.layers.output.imgs[0].texture);
gl.uniform1i(gl.getUniformLocation(programOutput, "uSrc"), 0);
gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
}
useDenoiser.denoisers['unet1'] = renderer.denoiser;
}


function applyResidualDenoiser(model) {
let gl = renderer.gl;

Expand Down Expand Up @@ -322,16 +325,15 @@ function applyNormalizedResidualDenoiser(model) {
gl.copyTexImage2D(gl.TEXTURE_2D,
0, gl.RGBA32F, 0, 0, state.width, state.height, 0);
// normalize
var mean = Dnn.global_mean(gl, layerTempI);
var std = Dnn.global_std(gl, layerTempI);
Dnn.batch_norm_2d(gl, layerTempI, model.layers.input, 0.0, 1.0, mean, std);
var ms = Dnn.global_mean_and_std(gl, layerTempI);
Dnn.batch_norm_2d(gl, layerTempI, model.layers.input, 0.0, 1.0, ms.mean, ms.std);
// inference residual model
model.forward();
Dnn.add(gl, model.layers.input, model.layers.output, layerTempI);
// normalize back
for (var i = 0; i < mean.length; i++)
std[i] = 1.0 / std[i], mean[i] *= -std[i];
Dnn.batch_norm_2d(gl, layerTempI, layerTempO, 0.0, 1.0, mean, std);
// Invert the normalization in place: std' = 1/std, mean' = -mean/std, so the
// following batch_norm_2d call maps the data back to its original scale.
// `ms` is used as an object holding `mean` and `std`, so the channel count is
// taken from ms.mean.length; `ms.length` would make the loop a no-op.
for (var i = 0; i < ms.mean.length; i++)
ms.std[i] = 1.0 / ms.std[i], ms.mean[i] *= -ms.std[i];
Dnn.batch_norm_2d(gl, layerTempI, layerTempO, 0.0, 1.0, ms.mean, ms.std);
// gamma transform
gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
gl.useProgram(programOutput);
Expand Down Expand Up @@ -540,12 +542,149 @@ function initDenoiserModel_runet2gan2(params) {
useDenoiser.denoisers['runet2gan2'] = renderer.denoiser;
}


/**
 * U-Net style network followed by a channel-attention and a spatial-attention
 * stage and a final convolution to 3 output channels.
 *
 * Must be called with `new`. Exposes:
 *   - this.layers        table of intermediate Dnn.CNNLayer buffers by name
 *                        ("input" and "output" are the network endpoints)
 *   - this.updateLayers  (re)allocates every buffer for the current
 *                        state.width / state.height
 *   - this.forward       runs the network from layers.input to layers.output
 *
 * NOTE(review): the Conv2d311 / Conv2d321 / Conv2d110 / ConvTranspose2D421
 * suffixes presumably encode kernel size, stride and padding - confirm in Dnn.
 *
 * @param {number} nin - channel count of the input layer
 * @param {number} k1 - channel count of the full-resolution encoder stage
 * @param {number} k2 - channel count at 1/2 resolution
 * @param {number} k3 - channel count at 1/4 resolution
 * @param {number} k4 - channel count at 1/8 resolution
 * @param {number} ko - channel count of the full-resolution decoder stage
 * @param {Object} params - weights and biases keyed by '<layer>.weight' / '<layer>.bias'
 * @throws {Error} if k1+ko is odd (the attention blocks use (k1+ko)/2 channels)
 */
function AttUNet1(nin, k1, k2, k3, k4, ko, params) {
    let gl = renderer.gl;
    // channel count of concat(e1, d02r), the tensor fed to the attention blocks
    let kc = k1+ko;
    if (kc % 2) throw new Error("Fractional attention layer size.");
    let enc1 = new Dnn.Conv2d311(nin, k1, params['enc1.weight'], params['enc1.bias']);
    let div1 = new Dnn.Conv2d321(k1, k2, params['div1.weight'], params['div1.bias']);
    let enc2 = new Dnn.Conv2d311(k2, k2, params['enc2.weight'], params['enc2.bias']);
    let div2 = new Dnn.Conv2d321(k2, k3, params['div2.weight'], params['div2.bias']);
    let enc3 = new Dnn.Conv2d311(k3, k3, params['enc3.weight'], params['enc3.bias']);
    let div3 = new Dnn.Conv2d321(k3, k4, params['div3.weight'], params['div3.bias']);
    let dec31 = new Dnn.Conv2d311(k4, k4, params['dec31.weight'], params['dec31.bias']);
    let dec32 = new Dnn.Conv2d311(k4, k4, params['dec32.weight'], params['dec32.bias']);
    let upc3 = new Dnn.ConvTranspose2D421(k4, k3, params['upc3.weight'], params['upc3.bias']);
    let dec21 = new Dnn.Conv2d311(k3, k3, params['dec21.weight'], params['dec21.bias']);
    let dec22 = new Dnn.Conv2d311(k3, k3, params['dec22.weight'], params['dec22.bias']);
    // the up-convolutions after the bottleneck take a skip connection
    // concatenated with the decoder output, hence the doubled input channels
    let upc2 = new Dnn.ConvTranspose2D421(k3+k3, k2, params['upc2.weight'], params['upc2.bias']);
    let dec11 = new Dnn.Conv2d311(k2, k2, params['dec11.weight'], params['dec11.bias']);
    let dec12 = new Dnn.Conv2d311(k2, k2, params['dec12.weight'], params['dec12.bias']);
    let upc1 = new Dnn.ConvTranspose2D421(k2+k2, ko, params['upc1.weight'], params['upc1.bias']);
    let dec01 = new Dnn.Conv2d311(ko, ko, params['dec01.weight'], params['dec01.bias']);
    let dec02 = new Dnn.Conv2d311(ko, ko, params['dec02.weight'], params['dec02.bias']);
    let att1conv1 = new Dnn.Conv2d110(kc, 1, params['attention1.conv1.weight'], params['attention1.conv1.bias']);
    let att1conv2 = new Dnn.Conv2d110(kc, kc/2, params['attention1.conv2.weight'], params['attention1.conv2.bias']);
    // attention1.conv3 is applied on the CPU to a per-channel vector (see
    // channelAttention), so only its raw weights and biases are kept
    let att1conv3 = { w: params['attention1.conv3.weight'], b: params['attention1.conv3.bias'] };
    let att2conv1 = new Dnn.Conv2d110(kc, kc/2, params['attention2.conv1.weight'], params['attention2.conv1.bias']);
    let att2conv2 = new Dnn.Conv2d110(kc, kc/2, params['attention2.conv2.weight'], params['attention2.conv2.bias']);
    let convo = new Dnn.Conv2d311(kc, 3, params['convo.weight'], params['convo.bias']);

    let layers = {};
    // (Re)create layer `key` with `n` channels at 1/scale of the padded size.
    function ul(key, n, scale) {
        // round the viewport up to a multiple of 16 so every scale divides evenly
        var w = Math.ceil(state.width/16)*16;
        var h = Math.ceil(state.height/16)*16;
        var oldLayer = layers[key];
        layers[key] = new Dnn.CNNLayer(gl, n, w/scale, h/scale);
        // release the previous buffer only after its replacement exists
        if (oldLayer) Dnn.destroyCnnLayer(gl, oldLayer);
    };
    this.layers = layers;
    // Does not use `this`, so it can be passed around as a bare callback.
    this.updateLayers = function() {
        ul("input", nin, 1);
        ul("e1", k1, 1);
        ul("e1r", k1, 1); ul("d1", k2, 2);
        ul("d1r", k2, 2); ul("e2", k2, 2);
        ul("e2r", k2, 2); ul("d2", k3, 4);
        ul("d2r", k3, 4); ul("e3", k3, 4);
        ul("e3r", k3, 4); ul("d3", k4, 8);
        ul("d3r", k4, 8); ul("d31", k4, 8); ul("d31r", k4, 8); ul("d32", k4, 8);
        ul("d32r", k4, 8); ul("u3", k3, 4);
        ul("u3r", k3, 4); ul("d21", k3, 4); ul("d21r", k3, 4); ul("d22", k3, 4);
        ul("d22r", k3, 4); ul("u2", k2, 2);
        ul("u2r", k2, 2); ul("d11", k2, 2); ul("d11r", k2, 2); ul("d12", k2, 2);
        ul("d12r", k2, 2); ul("u1", ko, 1);
        ul("u1r", ko, 1); ul("d01", ko, 1); ul("d01r", ko, 1); ul("d02", ko, 1);
        ul("d02r", ko, 1);
        ul("att1x1", 1, 1); ul("att1x1a", 1, 1); ul("att1x2", kc/2, 1); ul("att1o", kc, 1);
        ul("att2x1", kc/2, 1); ul("att2x2", kc/2, 1); ul("att2x3", 1, 1); ul("att2x3a", 1, 1); ul("att2o", kc, 1);
        ul("output", 3, 1);
    }
    this.updateLayers();

    // Channel attention: writes layers.att1o from x, using one sigmoid gate
    // per channel computed on the CPU.
    function channelAttention(x) {
        att1conv1.forward(gl, x, layers.att1x1);
        Dnn.softmax2d(gl, layers.att1x1, layers.att1x1a);
        att1conv2.forward(gl, x, layers.att1x2);
        let x3 = Dnn.global_dot(gl, layers.att1x1a, layers.att1x2);
        // apply attention1.conv3 as a (kc x kc/2) matrix-vector product
        // followed by a sigmoid
        let x3a = new Array(kc);
        for (var i = 0; i < kc; i++) {
            var s = att1conv3.b[i];
            for (var j = 0; j < kc/2; j++)
                s += att1conv3.w[i*(kc/2)+j] * x3[j];
            x3a[i] = 1.0 / (1.0+Math.exp(-s));
        }
        // NOTE(review): batch_norm_2d is presumably used here as a per-channel
        // scale by x3a - confirm against its argument order in Dnn
        Dnn.batch_norm_2d(gl, x, layers.att1o, 0.0, x3a, 0.0, 1.0);
    }

    // Spatial attention: writes layers.att2o = x multiplied by a
    // single-channel sigmoid map.
    function spacialAttention(x) {
        att2conv1.forward(gl, x, layers.att2x1);
        let x1 = Dnn.global_mean(gl, layers.att2x1);
        // softmax over the kc/2 channel means; subtracting the maximum first
        // leaves the result unchanged but keeps Math.exp from overflowing to
        // Infinity (which would turn every weight into NaN)
        var xmax = -Infinity;
        for (var i = 0; i < kc/2; i++)
            xmax = Math.max(xmax, x1[i]);
        var sexp = 0.0;
        for (var i = 0; i < kc/2; i++)
            sexp += (x1[i] = Math.exp(x1[i] - xmax));
        for (var i = 0; i < kc/2; i++)
            x1[i] /= sexp;
        att2conv2.forward(gl, x, layers.att2x2);
        Dnn.channel_sum(gl, layers.att2x2, layers.att2x3, x1);
        Dnn.sigmoid(gl, layers.att2x3, layers.att2x3a);
        Dnn.mul(gl, x, layers.att2x3a, layers.att2o);
    }

    // Runs the whole network; the caller fills layers.input beforehand and
    // reads layers.output afterwards.
    this.forward = function() {
        // encoder
        enc1.forward(gl, layers.input, layers.e1);
        Dnn.relu(gl, layers.e1, layers.e1r);
        div1.forward(gl, layers.e1r, layers.d1);
        Dnn.relu(gl, layers.d1, layers.d1r);
        enc2.forward(gl, layers.d1r, layers.e2);
        Dnn.relu(gl, layers.e2, layers.e2r);
        div2.forward(gl, layers.e2r, layers.d2);
        Dnn.relu(gl, layers.d2, layers.d2r);
        enc3.forward(gl, layers.d2r, layers.e3);
        Dnn.relu(gl, layers.e3, layers.e3r);
        div3.forward(gl, layers.e3r, layers.d3);
        Dnn.relu(gl, layers.d3, layers.d3r);
        // bottleneck and decoder
        dec31.forward(gl, layers.d3r, layers.d31);
        Dnn.relu(gl, layers.d31, layers.d31r);
        dec32.forward(gl, layers.d31r, layers.d32);
        Dnn.relu(gl, layers.d32, layers.d32r);
        upc3.forward(gl, layers.d32r, layers.u3);
        Dnn.relu(gl, layers.u3, layers.u3r);
        dec21.forward(gl, layers.u3r, layers.d21);
        Dnn.relu(gl, layers.d21, layers.d21r);
        dec22.forward(gl, layers.d21r, layers.d22);
        Dnn.relu(gl, layers.d22, layers.d22r);
        // NOTE(review): the skip connections below use the pre-ReLU encoder
        // outputs (e3, e2, e1) - confirm this matches the trained model
        upc2.forward(gl, Dnn.shallowConcat(layers.e3, layers.d22r), layers.u2);
        Dnn.relu(gl, layers.u2, layers.u2r);
        dec11.forward(gl, layers.u2r, layers.d11);
        Dnn.relu(gl, layers.d11, layers.d11r);
        dec12.forward(gl, layers.d11r, layers.d12);
        Dnn.relu(gl, layers.d12, layers.d12r);
        upc1.forward(gl, Dnn.shallowConcat(layers.e2, layers.d12r), layers.u1);
        Dnn.relu(gl, layers.u1, layers.u1r);
        dec01.forward(gl, layers.u1r, layers.d01);
        Dnn.relu(gl, layers.d01, layers.d01r);
        dec02.forward(gl, layers.d01r, layers.d02);
        Dnn.relu(gl, layers.d02, layers.d02r);
        // attention stages and output convolution
        channelAttention(Dnn.shallowConcat(layers.e1, layers.d02r));
        // layers.output = layers.att1x1a;  // visualize attention map
        spacialAttention(layers.att1o);
        // layers.output = layers.att2o;
        convo.forward(gl, layers.att2o, layers.output);
    };
}



/**
 * Builds the experimental 'temp' denoiser and registers it.
 *
 * Currently wires an AttUNet1 through applyResidualDenoiser(); the
 * commented-out lines are the alternative model / pipeline combinations kept
 * for experimentation. Exactly one model and one pipeline must be active:
 * declaring `unet` twice is a SyntaxError, and applying two pipelines would
 * only let the last one take effect.
 *
 * @param {Object} params - weight/bias table forwarded unchanged to the model.
 */
function initDenoiserModel_temp(params) {
    // let unet = new UNet2(3, 12, 16, 24, 32, params);
    let unet = new AttUNet1(3, 12, 16, 24, 32, 12, params);
    // re-allocate the model's layer buffers shortly after a window resize
    window.addEventListener("resize", function (event) {
        setTimeout(unet.updateLayers, 20);
    });
    // applyNormalizedResidualDenoiser(unet);
    applyResidualDenoiser(unet);
    // applyDenoiser(unet);
    useDenoiser.denoisers['temp'] = renderer.denoiser;
}

Expand Down
12 changes: 10 additions & 2 deletions implicit3-rt/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
import numpy as np
import json

# Empty stand-in module classes. They carry no layers or logic of their own.
# NOTE(review): presumably these exist only so that torch.load() can resolve
# the class names stored in the pickled checkpoint - confirm against the
# training script that defines the real classes.
class AttentionChannelOnly(torch.nn.Module):
pass

# Stand-in for the spatial-attention block (same purpose as above).
class AttentionSpacialOnly(torch.nn.Module):
pass

# Stand-in for the top-level network class.
class Model(torch.nn.Module):
pass

# Load the checkpoint onto the CPU. Exactly one checkpoint path may be given:
# a second positional string would be taken as `map_location` and collide
# with the keyword argument below.
model = torch.load(
    '../../Graphics/image/denoise/data_spirulae_5/attunet1_1_1.pth',
    map_location=torch.device('cpu'))

# Parameter table (name -> tensor) that gets exported.
state_dict = model.state_dict()
Expand All @@ -24,6 +30,8 @@ class Model(torch.nn.Module):
amin, amax = np.amin(tensor), np.amax(tensor)
vmin, vmax = -2**(nbit-1)+0.1, 2**(nbit-1)-1.1
m = (amax-amin) / (vmax-vmin)
if m == 0.0:
m = 1.0
b = amin - m * vmin
item = {
'shape': [*tensor.shape],
Expand All @@ -38,5 +46,5 @@ class Model(torch.nn.Module):

name = "temp"
# Write the metadata exactly once, with compact separators to keep the file
# small; dumping twice to the same handle would concatenate two JSON documents
# and make the file unparseable.
with open(f"denoise_models/denoise_{name}.json", 'w') as fp:
    json.dump(info, fp, separators=(',', ':'))
# Raw parameter bytes go to the companion .bin file.
data.tofile(f"denoise_models/denoise_{name}.bin")
2 changes: 1 addition & 1 deletion implicit3-rt/script.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 94de717

Please sign in to comment.