Skip to content

Commit

Permalink
StaticArrayInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Apr 25, 2024
1 parent 168d962 commit 7761625
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"

[compat]
Expand Down
17 changes: 17 additions & 0 deletions src/TriangularSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using Static
using IfElse: ifelse
using LoopVectorization
using Polyester
using StaticArrayInterface

const LPtr{T} = Core.LLVMPtr{T,0}
_lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x)
Expand Down Expand Up @@ -718,8 +719,24 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
end
Base.size(A::Mat) = (A.M, A.N)
Base.axes(A::Mat) = (CloseOpen(A.M), CloseOpen(A.N))
Base.strides(A::Mat{T,true}) where {T} = (1, getfield(A, :x))
Base.strides(A::Mat{T,false}) where {T} = (getfield(A, :x), 1)
Base.transpose(A::Mat{T,true}) where {T} = Mat{T,false}(A.p, A.x, A.N, A.M)
Base.transpose(A::Mat{T,false}) where {T} = Mat{T,true}(A.p, A.x, A.N, A.M)
Base.pointer(A::Mat) = getfield(A, :p)
StaticArrayInterface.device(::Mat) = StaticArrayInterface.CPUPointer()
StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} =
(static(1), getfield(A, :x))
StaticArrayInterface.static_strides(A::Mat{T,false}) where {T} =
(getfield(A, :x), static(1))
StaticArrayInterface.offsets(::Mat) = (static(0), static(0))
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) = (static(1), static(2))
StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) = (static(2), static(1))
StaticArrayInterface.contiguous_batch_size(::Type{<:Mat}) = static(0)
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) = (static(true),static(false))
StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) = (static(false),static(true))
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,true}}) = static(1)
StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,false}}) = static(2)
@inline function Base.getindex(
A::Mat{T,ColMajor},
i::Int,
Expand Down

0 comments on commit 7761625

Please sign in to comment.