Skip to content

Commit 7761625

Browse files
committed
StaticArrayInterface
1 parent 168d962 commit 7761625

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1212
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
1313
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
14+
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
1415
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1516

1617
[compat]

src/TriangularSolve.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ using Static
1919
using IfElse: ifelse
2020
using LoopVectorization
2121
using Polyester
22+
using StaticArrayInterface
2223

2324
const LPtr{T} = Core.LLVMPtr{T,0}
2425
_lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x)
@@ -718,8 +719,24 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
718719
end
719720
Base.size(A::Mat) = (A.M, A.N)
720721
Base.axes(A::Mat) = (CloseOpen(A.M), CloseOpen(A.N))
722+
Base.strides(A::Mat{T,true}) where {T} = (1, getfield(A, :x))
723+
Base.strides(A::Mat{T,false}) where {T} = (getfield(A, :x), 1)
721724
Base.transpose(A::Mat{T,true}) where {T} = Mat{T,false}(A.p, A.x, A.N, A.M)
722725
Base.transpose(A::Mat{T,false}) where {T} = Mat{T,true}(A.p, A.x, A.N, A.M)
726+
Base.pointer(A::Mat) = getfield(A, :p)
727+
StaticArrayInterface.device(::Mat) = StaticArrayInterface.CPUPointer()
728+
StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} =
729+
(static(1), getfield(A, :x))
730+
StaticArrayInterface.static_strides(A::Mat{T,false}) where {T} =
731+
(getfield(A, :x), static(1))
732+
StaticArrayInterface.offsets(::Mat) = (static(0), static(0))
733+
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) = (static(1), static(2))
734+
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) = (static(2), static(1))
735+
StaticArrayInterface.contiguous_batch_size(::Type{<:Mat}) = static(0)
736+
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) = (static(true),static(false))
737+
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) = (static(false),static(true))
738+
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,true}}) = static(1)
739+
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,false}}) = static(2)
723740
@inline function Base.getindex(
724741
A::Mat{T,ColMajor},
725742
i::Int,

0 commit comments

Comments
 (0)