Skip to content
This repository was archived by the owner on Jun 22, 2021. It is now read-only.

Commit 4411140

Browse files
committed
implement scitype for "points on a manifold" #46
1 parent fd9ff43 commit 4411140

File tree

5 files changed

+45
-9
lines changed

5 files changed

+45
-9
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ version = "0.3.0"
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
88
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
99
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
10+
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
1011
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1112
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
1213
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1314

1415
[compat]
1516
CategoricalArrays = "^0.8"
1617
ColorTypes = "^0.9,^0.10"
18+
ManifoldsBase = "^0.9.5"
1719
PrettyTables = "^0.8,^0.9"
1820
ScientificTypes = "^1.0"
1921
Tables = "^1.0"

src/MLJScientificTypes.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@ module MLJScientificTypes
22

33
# Dependencies
44
using ScientificTypes
5-
using Tables, CategoricalArrays, ColorTypes, PrettyTables, Dates
5+
using Tables, CategoricalArrays, ColorTypes, PrettyTables, Dates,
6+
ManifoldsBase
67

78
# re-exports from ScientificTypes
89
export Scientific, Found, Unknown, Known, Finite, Infinite,
910
OrderedFactor, Multiclass, Count, Continuous, Textual,
1011
Binary, ColorImage, GrayImage, Image, Table,
1112
ScientificTimeType, ScientificDate, ScientificDateTime,
12-
ScientificTime
13+
ScientificTime, ManifoldPoint
1314
export scitype, scitype_union, elscitype, nonmissing, trait
1415

1516
# exports

src/convention/scitype.jl

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ function ST.scitype(A::CArr{T,N}, ::MLJ) where {T,N}
2626
return AbstractArray{S,N}
2727
end
2828

29+
# Manifold scitype
30+
31+
ST.scitype(::Tuple{Any,MT}) where MT<:ManifoldsBase.Manifold = ManifoldPoint{MT}
32+
2933
# Table scitype
3034

3135
function ST.scitype(X, ::MLJ, ::Val{:table}; kw...)
@@ -39,10 +43,21 @@ end
3943

4044
# Scitype for fast array broadcasting
4145

42-
ST.Scitype(::Type{<:Integer}, ::MLJ) = Count
43-
ST.Scitype(::Type{<:AbstractFloat}, ::MLJ) = Continuous
44-
ST.Scitype(::Type{<:AbstractString}, ::MLJ) = Textual
45-
ST.Scitype(::Type{<:TimeType}, ::MLJ) = ScientificTimeType
46-
ST.Scitype(::Type{<:Date}, ::MLJ) = ScientificDate
47-
ST.Scitype(::Type{<:Time}, ::MLJ) = ScientificTime
48-
ST.Scitype(::Type{<:DateTime}, ::MLJ) = ScientificDateTime
46+
const Point{MT} = Tuple{Any,MT}
47+
const Manifold = ManifoldsBase.Manifold
48+
49+
ST.Scitype(::Type{<:Integer}, ::MLJ) = Count
50+
ST.Scitype(::Type{<:AbstractFloat}, ::MLJ) = Continuous
51+
ST.Scitype(::Type{<:AbstractString}, ::MLJ) = Textual
52+
ST.Scitype(::Type{<:TimeType}, ::MLJ) = ScientificTimeType
53+
ST.Scitype(::Type{<:Date}, ::MLJ) = ScientificDate
54+
ST.Scitype(::Type{<:Time}, ::MLJ) = ScientificTime
55+
ST.Scitype(::Type{<:DateTime}, ::MLJ) = ScientificDateTime
56+
57+
# Next two lines don't work https://github.com/JuliaLang/julia/issues/37703 :
58+
# ST.Scitype(::Type{<:Point{MT}}, ::MLJ) where MT<:ManifoldsBase.Manifold =
59+
# ManifoldPoint{MT}
60+
61+
# TODO: Remove the following hack when above issue is resolved:
62+
ST.Scitype(T::Type{<:Point{<:Manifold}}, ::MLJ) = ManifoldPoint{last(T.types)}
63+

test/basic_tests.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,23 @@ end
102102
AbstractVector{Union{Missing,ScientificTimeType}}
103103
end
104104

105+
struct MySphere{N} <: ManifoldsBase.Manifold{ManifoldsBase.ℝ} where {N}
106+
radius::Float64
107+
end
108+
MySphere(radius, n) = MySphere{n}(radius)
109+
110+
@testset "manifold point" begin
111+
manifold1 = MySphere(1, 3)
112+
@test scitype(("some_point_representation", manifold1)) ==
113+
ManifoldPoint{MySphere{3}}
114+
v1 = [(rand(), manifold1) for _ in 1:4]
115+
@test elscitype(v1) == ManifoldPoint{MySphere{3}}
116+
@test scitype(v1) == AbstractVector{ManifoldPoint{MySphere{3}}}
117+
manifold2 = MySphere(1, 4)
118+
v2 = [(rand(), manifold2) for _ in 1:3]
119+
@test scitype(vcat(v1, v2)) <: AbstractVector{<:ManifoldPoint{<:MySphere}}
120+
end
121+
105122
@testset "Type coercion" begin
106123
X = (x=10:10:44, y=1:4, z=collect("abcd"))
107124
types = Dict(:x => Continuous, :z => Multiclass)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test, ScientificTypes, MLJScientificTypes, Random
22
using Tables, CategoricalArrays, CSV, DataFrames, ColorTypes
3+
import ManifoldsBase
34
using Dates
45

56
const Arr = AbstractArray

0 commit comments

Comments
 (0)