Skip to content

Commit 86e8379

Browse files
authored
ENH: better support for different eltype in coefs and BasisMatrix when calling funeval closes #40 (#41)
1 parent f7b7b54 commit 86e8379

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/basis_structure.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ mutable struct BasisMatrix{BST<:ABSR, TM<:AbstractMatrix}
1414
vals::Matrix{TM}
1515
end
1616

17+
Base.eltype(bm::BasisMatrix{BST,TM}) where {BST, TM} = eltype(TM)
18+
1719
Base.show(io::IO, b::BasisMatrix{BST}) where {BST} =
1820
print(io, "BasisMatrix{$BST} of order $(b.order)")
1921

src/interp.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ function _funeval(c, bs::BasisMatrix{Tensor}, order::AbstractMatrix{Int}) # fun
8989
# 99
9090
nx = prod([size(bs.vals[1, j], 1) for j=1:d])
9191

92-
f = Array{eltype(c),3}(nx, size(c, 2), kk) # 100
92+
_T = promote_type(eltype(c), eltype(bs))
93+
f = Array{_T,3}(nx, size(c, 2), kk) # 100
9394

9495
for i=1:kk
9596
f[:, :, i] = ckronx(bs.vals, c, order[i, :]) # 102
@@ -102,7 +103,8 @@ function _funeval(c, bs::BasisMatrix{Direct}, order::AbstractMatrix{Int}) # fun
102103
# 114 reverse the order of evaluation: B(d)xB(d-1)x...xB(1)
103104
order = flipdim(order .+ (size(bs.vals, 1)*(0:d-1)' - bs.order+1), 2)
104105

105-
f = Array{eltype(c),3}(size(bs.vals[1], 1), size(c, 2), kk) # 116
106+
_T = promote_type(eltype(c), eltype(bs))
107+
f = Array{_T,3}(size(bs.vals[1], 1), size(c, 2), kk) # 116
106108

107109
for i in 1:kk
108110
f[:, :, i] = cdprodx(bs.vals, c, order[i, :]) # 118
@@ -113,7 +115,9 @@ end
113115
function _funeval(c, bs::BasisMatrix{Expanded}, order::AbstractMatrix{Int}) # funeval3
114116
nx = size(bs.vals[1], 1)
115117
kk = size(order, 1)
116-
f = Array{promote_type(eltype(c),eltype(bs.vals[1])),3}(nx, size(c, 2), kk)
118+
119+
_T = promote_type(eltype(c), eltype(bs))
120+
f = Array{_T,3}(nx, size(c, 2), kk)
117121
for i=1:kk
118122
this_order = order[i, :]
119123
ind = findfirst(x->bs.order[x, :] == this_order, 1:kk)

0 commit comments

Comments
 (0)