Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit ed65ecd

Browse files
committed
complete legendre_filter
1 parent 60f1faa commit ed65ecd

File tree

6 files changed

+70
-12
lines changed

6 files changed

+70
-12
lines changed

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1111
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1212
GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
1313
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
14+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
1516
SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

src/NeuralOperators.jl

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using GeometricFlux
1212
using Statistics
1313
using Polynomials
1414
using SpecialPolynomials
15+
using LinearAlgebra
1516

1617
# kernels
1718
include("Transform/Transform.jl")

src/Transform/polynomials.jl

+13-10
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,25 @@ function chebyshev_ϕ_ψ(k)
128128
end
129129

130130
function legendre_filter(k)
131-
H0 = zeros(k, k)legendre
131+
H0 = zeros(k, k)
132132
H1 = zeros(k, k)
133133
G0 = zeros(k, k)
134134
G1 = zeros(k, k)
135135
ϕ, ψ1, ψ2 = legendre_ϕ_ψ(k)
136136

137-
# roots = Poly(legendre(k, 2*x-1)).all_roots()
138-
# x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
139-
# wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1)
137+
l = convert(Polynomial, gen_poly(Legendre, k))
138+
x_m = roots(l(Polynomial([-1, 2]))) # 2x-1
139+
m = 2 .* x_m .- 1
140+
wm = 1 ./ k ./ legendre_der.(k, m) ./ gen_poly(Legendre, k-1).(m)
140141

141-
# for ki in range(k):
142-
# for kpi in range(k):
143-
# H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum()
144-
# G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum()
145-
# H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum()
146-
# G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum()
142+
for ki in 0:(k-1)
143+
for kpi in 0:(k-1)
144+
H0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].(x_m/2) .* ϕ[kpi+1].(x_m))
145+
G0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, x_m/2) .* ϕ[kpi+1].(x_m))
146+
H1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].((x_m.+1)/2) .* ϕ[kpi+1].(x_m))
147+
G1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m.+1)/2) .* ϕ[kpi+1].(x_m))
148+
end
149+
end
147150

148151
zero_out!(H0)
149152
zero_out!(H1)

src/Transform/utils.jl

+12-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.)
77
end
88

99
"""
    ψ(ψ1, ψ2, i, inp)

Evaluate the piecewise wavelet `ψ_i` at the points in `inp` (broadcast over
`inp`): `ψ1[i+1]` is used where `inp > 0.5` and `ψ2[i+1]` elsewhere.
`i` is a 0-based index into the 1-based vectors `ψ1`/`ψ2`.
"""
function ψ(ψ1, ψ2, i, inp)
    mask = (inp .> 0.5) .* 1.0
    # The two pieces are complementary: ψ2 applies where the mask is 0.
    # (The previous revision multiplied both terms by `mask`, dropping the
    # ψ2 contribution entirely and doubling nothing where mask == 1.)
    return ψ1[i+1].(inp) .* mask .+ ψ2[i+1].(inp) .* (1 .- mask)
end
1313

1414
zero_out!(x; tol=1e-8) = (x[abs.(x) .< tol] .= 0)
@@ -35,3 +35,13 @@ function proj_factor(a, b; complement::Bool=false)
3535
proj_ = sum(prod_ ./ r .* s)
3636
return proj_
3737
end
38+
39+
# Single term (2k + 1)·P_k(x) of the Legendre-derivative recurrence;
# used by `legendre_der`.
function _legendre(k, x)
    return (2 * k + 1) * gen_poly(Legendre, k)(x)
end
40+
41+
"""
    legendre_der(k, x)

Derivative of the degree-`k` Legendre polynomial at `x`, via the recurrence
P_k'(x) = Σ (2i + 1)·P_i(x) over i = k-1, k-3, …, ending at 0 or 1.
"""
function legendre_der(k, x)
    out = 0
    # Stop at 0, not -1: `k-1:-2:-1` includes i = -1 when `k` is even,
    # which asks for an invalid degree-(-1) Legendre polynomial.
    # (The Python reference uses range(k-1, -1, -2), which excludes -1.)
    for i in (k-1):-2:0
        out += _legendre(i, x)
    end
    return out
end

test/polynomials.jl

+42
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,46 @@
6464
@test ψ2[3](1) ≈ -1.0941547380212384
6565
@test ψ2[3](2) ≈ 0.
6666
end
67+
68+
@testset "legendre_filter" begin
69+
H0, H1, G0, G1, Φ1, Φ2 = NeuralOperators.legendre_filter(3)
70+
71+
@test H0 ≈ [0.70710678 0. 0. ;
72+
-0.61237244 0.35355339 0. ;
73+
0. -0.6846532 0.1767767]
74+
@test H1 ≈ [0.70710678 0. 0. ;
75+
0.61237244 0.35355339 0. ;
76+
0. 0.6846532 0.1767767]
77+
@test G0 ≈ [0.35355339 0.61237244 0. ;
78+
0. 0.1767767 0.6846532 ;
79+
0. 0. 0.70710678]
80+
@test G1 ≈ [-0.35355339 0.61237244 0. ;
81+
0. -0.1767767 0.6846532 ;
82+
0. 0. -0.70710678]
83+
@test Φ1 == I(3)
84+
@test Φ2 == I(3)
85+
end
86+
87+
@testset "chebyshev_filter" begin
88+
# H0, H1, G0, G1, Φ1, Φ2 = NeuralOperators.chebyshev_filter(3)
89+
90+
# @test H0 ≈ [0.70710678 0. 0. ;
91+
# -0.5 0.35355339 0. ;
92+
# -0.25 -0.70710678 0.1767767]
93+
# @test H1 ≈ [0.70710678 0. 0. ;
94+
# 0.5 0.35355339 0. ;
95+
# -0.25 0.70710678 0.1767767]
96+
# @test G0 ≈ [0.60944614 0.77940383 0. ;
97+
# 0.66325172 1.02726613 1.14270252;
98+
# 0.61723435 0.90708619 1.1562954 ]
99+
# @test G1 ≈ [-0.60944614 0.77940383 0. ;
100+
# 0.66325172 -1.02726613 1.14270252;
101+
# -0.61723435 0.90708619 -1.1562954 ]
102+
# @test Φ1 ≈ [1. -0.40715364 -0.21440101;
103+
# -0.40715364 0.84839559 -0.44820615;
104+
# -0.21440101 -0.44820615 0.84002127]
105+
# @test Φ2 ≈ [1. 0.40715364 -0.21440101;
106+
# 0.40715364 0.84839559 0.44820615;
107+
# -0.21440101 0.44820615 0.84002127]
108+
end
67109
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using CUDA
33
using Flux
44
using GeometricFlux
55
using Graphs
6+
using LinearAlgebra
67
using Polynomials
78
using Zygote
89
using Test

0 commit comments

Comments
 (0)