From 7761625ae089f13c9fb94d186b27e27ed8b3afc5 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 25 Apr 2024 19:26:22 -0400 Subject: [PATCH] StaticArrayInterface --- Project.toml | 1 + src/TriangularSolve.jl | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/Project.toml b/Project.toml index 15ef9e9..6dbde67 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" [compat] diff --git a/src/TriangularSolve.jl b/src/TriangularSolve.jl index a16cbb3..2846073 100644 --- a/src/TriangularSolve.jl +++ b/src/TriangularSolve.jl @@ -19,6 +19,7 @@ using Static using IfElse: ifelse using LoopVectorization using Polyester +using StaticArrayInterface const LPtr{T} = Core.LLVMPtr{T,0} _lptr(x::Ptr{T}) where {T} = reinterpret(LPtr{T}, x) @@ -718,8 +719,24 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T} end Base.size(A::Mat) = (A.M, A.N) Base.axes(A::Mat) = (CloseOpen(A.M), CloseOpen(A.N)) +Base.strides(A::Mat{T,true}) where {T} = (1, getfield(A, :x)) +Base.strides(A::Mat{T,false}) where {T} = (getfield(A, :x), 1) Base.transpose(A::Mat{T,true}) where {T} = Mat{T,false}(A.p, A.x, A.N, A.M) Base.transpose(A::Mat{T,false}) where {T} = Mat{T,true}(A.p, A.x, A.N, A.M) +Base.pointer(A::Mat) = getfield(A, :p) +StaticArrayInterface.device(::Mat) = StaticArrayInterface.CPUPointer() +StaticArrayInterface.static_strides(A::Mat{T,true}) where {T} = + (static(1), getfield(A, :x)) +StaticArrayInterface.static_strides(A::Mat{T,false}) where {T} = + (getfield(A, :x), static(1)) +StaticArrayInterface.offsets(::Mat) = (static(0), static(0)) +StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,true}}) = (static(1), static(2)) +StaticArrayInterface.stride_rank(::Type{<:Mat{<:Any,false}}) = (static(2), static(1)) +StaticArrayInterface.contiguous_batch_size(::Type{<:Mat}) = static(0) +StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,true}}) = (static(true),static(false)) +StaticArrayInterface.dense_dims(::Type{<:Mat{<:Any,false}}) = (static(false),static(true)) +StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,true}}) = static(1) +StaticArrayInterface.contiguous_axis(::Type{<:Mat{<:Any,false}}) = static(2) @inline function Base.getindex( A::Mat{T,ColMajor}, i::Int,