@@ -19,6 +19,7 @@ using Static
19
19
using IfElse: ifelse
20
20
using LoopVectorization
21
21
using Polyester
22
+ using StaticArrayInterface
22
23
23
24
const LPtr{T} = Core. LLVMPtr{T,0 }
24
25
_lptr (x:: Ptr{T} ) where {T} = reinterpret (LPtr{T}, x)
@@ -718,8 +719,24 @@ struct Mat{T,ColMajor} <: AbstractMatrix{T}
718
719
end
719
720
Base. size (A:: Mat ) = (A. M, A. N)
720
721
Base. axes (A:: Mat ) = (CloseOpen (A. M), CloseOpen (A. N))
722
+ Base. strides (A:: Mat{T,true} ) where {T} = (1 , getfield (A, :x ))
723
+ Base. strides (A:: Mat{T,false} ) where {T} = (getfield (A, :x ), 1 )
721
724
Base. transpose (A:: Mat{T,true} ) where {T} = Mat {T,false} (A. p, A. x, A. N, A. M)
722
725
Base. transpose (A:: Mat{T,false} ) where {T} = Mat {T,true} (A. p, A. x, A. N, A. M)
726
+ Base. pointer (A:: Mat ) = getfield (A, :p )
727
+ StaticArrayInterface. device (:: Mat ) = StaticArrayInterface. CPUPointer ()
728
+ StaticArrayInterface. static_strides (A:: Mat{T,true} ) where {T} =
729
+ (static (1 ), getfield (A, :x ))
730
+ StaticArrayInterface. static_strides (A:: Mat{T,false} ) where {T} =
731
+ (getfield (A, :x ), static (1 ))
732
+ StaticArrayInterface. offsets (:: Mat ) = (static (0 ), static (0 ))
733
+ StaticArrayInterface. stride_rank (:: Type{<:Mat{<:Any,true}} ) = (static (1 ), static (2 ))
734
+ StaticArrayInterface. stride_rank (:: Type{<:Mat{<:Any,false}} ) = (static (2 ), static (1 ))
735
+ StaticArrayInterface. contiguous_batch_size (:: Type{<:Mat} ) = static (0 )
736
+ StaticArrayInterface. dense_dims (:: Type{<:Mat{<:Any,true}} ) = (static (true ),static (false ))
737
+ StaticArrayInterface. dense_dims (:: Type{<:Mat{<:Any,false}} ) = (static (false ),static (true ))
738
+ StaticArrayInterface. contiguous_axis (:: Type{<:Mat{<:Any,true}} ) = static (1 )
739
+ StaticArrayInterface. contiguous_axis (:: Type{<:Mat{<:Any,false}} ) = static (2 )
723
740
@inline function Base. getindex (
724
741
A:: Mat{T,ColMajor} ,
725
742
i:: Int ,
0 commit comments