From c92d3ed1eca77ea61b756864bded99e6f42dc878 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Sat, 19 Feb 2022 15:59:57 +0530 Subject: [PATCH] Some tests for forward pass and gradients --- test/deeponet.jl | 19 ++++++++++++++++++- test/fourierlayer.jl | 25 ++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/test/deeponet.jl b/test/deeponet.jl index 0f3a482..84839e0 100644 --- a/test/deeponet.jl +++ b/test/deeponet.jl @@ -14,4 +14,21 @@ using Test, Random, Flux # Accept only Int as architecture parameters @test_throws MethodError DeepONet((32.5,64,72), (24,48,72), σ, tanh) @test_throws MethodError DeepONet((32,64,72), (24.1,48,72)) -end \ No newline at end of file +end + +#Just the first 16 datapoints from the Burgers' equation dataset +a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755, 0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651, 0.81943734, 0.81737952, 0.8152405, 0.81302771] +sensors = collect(range(0, 1, length=16))' + +model = DeepONet((16, 22, 30), (1, 16, 24, 30), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false) + +model(a,sensors) + +#forward pass +@test size(model(a, sensors)) == (1, 16) + +mgrad = Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors) + +#gradients +@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[1]) +@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[2]) diff --git a/test/fourierlayer.jl b/test/fourierlayer.jl index 89bfa30..f6f2567 100644 --- a/test/fourierlayer.jl +++ b/test/fourierlayer.jl @@ -28,4 +28,27 @@ using Test, Random, Flux # Test max amount of modes @test_throws AssertionError FourierLayer(100, 100, 100, 60, σ) @test_throws AssertionError FourierLayer(100, 100, 100, 60) -end \ No newline at end of file +end + +#Just the first 16 data points from Burgers' equation dataset +xtrain = Float32[0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755, 0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651, 0.81943734, 0.81737952, 0.8152405, 0.81302771] +grid = Float32.(collect(range(0, 1, length=16))') + +x = cat(reshape(xtrain,(1,16,1)), + reshape(repeat(grid,1),(1,16,1)); + dims=3) + +x = permutedims(x,(3,2,1)) +layer = FourierLayer(64, 64, 16, 8, gelu, bias_fourier=false) +model = Chain(Dense(2,64;bias=false), layer, layer, layer, layer, + Dense(64,2;bias=false)) + +model(x) + +#forward pass +@test size(model(x)) == (2, 16, 1) + +Flux.Zygote.gradient((x)->sum(model(x)), x) + +#gradient test +@test !iszero(Flux.Zygote.gradient((x)->sum(model(x)), x)[1])