Skip to content

Commit d4847fb

Browse files
committed
Merge remote-tracking branch 'origin/master' into feature/generate_swiss_roll
2 parents d5532b0 + 22badcb commit d4847fb

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Dataset | Title
3030
make_blobs | Generate isotropic Gaussian blobs for clustering. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html)
3131
make_moons | Make two interleaving half circles | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html)
3232
make_s_curve | Generate an S curve dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html)
33+
make_circles | Make a large circle containing a smaller circle in 2d | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html)
3334
make_regression | Generate a random regression problem. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html)
3435
make_classification | Generate a random n-class classification problem. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html)
3536

src/sklearn.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,37 @@ function generate_s_curve(; n_samples::Int = 100,
8585
return convert(features, labels)
8686
end
8787

88+
"""
89+
generate_circles(; n_samples::Union{Int, Tuple{Int, Int}} = 100,
90+
shuffle::Bool = true,
91+
noise::Union{Nothing, Float64} = nothing,
92+
random_state::Union{Int, Nothing} = nothing,
93+
factor::Float64 = 0.8)::DataFrame
94+
Make a large circle containing a smaller circle in 2d. Sklearn interface to make_circles.
95+
# Arguments
96+
- `n_samples::Union{Int, Tuple{Int, Int}} = 100`: If int, it is the total number of points generated. For odd numbers, the inner circle will have one point more than the outer circle. If two-element tuple, number of points in outer circle and inner circle.
97+
- `shuffle::Bool = true`: Whether to shuffle the samples.
98+
- `noise::Union{Nothing, Float64} = nothing`: Standard deviation of Gaussian noise added to the data.
99+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset shuffling and noise. Pass an int for reproducible output across multiple function calls.
100+
- `factor::Float64 = 0.8`: Scale factor between inner and outer circle.
101+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html)
102+
103+
"""
104+
function generate_circles(; n_samples::Union{Int, Tuple{Int, Int}} = 100,
105+
shuffle::Bool = true,
106+
noise::Union{Nothing, Float64} = nothing,
107+
random_state::Union{Int, Nothing} = nothing,
108+
factor::Float64 = 0.8)::DataFrame
109+
110+
(features, labels) = datasets.make_circles( n_samples = n_samples,
111+
shuffle = shuffle,
112+
noise = noise,
113+
random_state = random_state,
114+
factor = factor)
115+
116+
return convert(features, labels)
117+
end
118+
88119
"""
89120
generate_regression(; n_samples::Int = 100,
90121
n_features::Int = 100,
@@ -124,7 +155,6 @@ function generate_regression(; n_samples::Int = 100,
124155
coef::Bool = false,
125156
random_state::Union{Int, Nothing}= nothing)
126157

127-
128158
(features, labels) = datasets.make_regression( n_samples = n_samples,
129159
n_features = n_features,
130160
n_informative = n_informative,
@@ -136,10 +166,8 @@ function generate_regression(; n_samples::Int = 100,
136166
shuffle = shuffle,
137167
coef = coef,
138168
random_state = random_state)
139-
140169

141170
return convert(features, labels)
142-
143171
end
144172

145173
"""
@@ -193,7 +221,6 @@ function generate_classification(; n_samples::Int = 100,
193221
shuffle::Bool = true,
194222
random_state::Union{Int, Nothing} = nothing)
195223

196-
197224
(features, labels) = datasets.make_classification( n_samples = n_samples,
198225
n_features = n_features,
199226
n_informative = n_informative,

test/runtests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ using Test
2727
@test size(data)[1] == samples
2828
@test size(data)[2] == 4
2929

30+
data = SyntheticDatasets.generate_circles(n_samples = samples)
31+
32+
@test size(data)[1] == samples
33+
@test size(data)[2] == 3
34+
3035
data = SyntheticDatasets.generate_regression(n_samples = samples,
3136
n_features = features,
3237
noise = 2.2,
@@ -46,6 +51,6 @@ using Test
4651
noise = 2.2,
4752
random_state = 5)
4853

49-
@test @show size(data)[1] == samples
54+
@test size(data)[1] == samples
5055
@test size(data)[2] == 4
5156
end

0 commit comments

Comments
 (0)