Skip to content

Commit e24b675

Browse files
Feature/generate_swiss_roll
1 parent aa9696d commit e24b675

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

src/sklearn.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,4 +211,27 @@ function generate_classification(; n_samples::Int = 100,
211211
random_state = random_state)
212212

213213
return convert(features, labels)
214-
end
214+
end
215+
"""
216+
function generate_swiss_roll(; n_samples::Int = 100,
217+
noise::Float64 = 0.0,
218+
random_state::Union{Int,Nothing} = nothing)
219+
Generate a swiss roll dataset.
220+
#Arguments
221+
- `n_samples::Int = 100`: The number of samples.
222+
- `noise::Float64 = 0.0 : Standard deviation of Gaussian noise added to the data.
223+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
224+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_swiss_roll.htmll)
225+
"""
226+
function generate_swiss_roll(; n_samples::Int = 100,
227+
noise::Float64 = 0.0,
228+
random_state::Union{Int,Nothing} = nothing)
229+
230+
231+
(features, labels) = datasets.make_swiss_roll( n_samples = n_samples,
232+
noise = noise,
233+
random_state = random_state)
234+
235+
return convert(features, labels)
236+
237+
end

test/runtests.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,17 @@ using Test
4040
n_features = features,
4141
n_classes = 1)
4242

43-
43+
4444
@test size(data)[1] == samples
4545
@test size(data)[2] == features + 1
4646

47-
end
47+
end
48+
data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,
49+
noise = 2.2,
50+
random_state = 5)
51+
52+
@test @show size(data)[1] == samples
53+
@test size(data)[2] == 4
54+
55+
56+
end

0 commit comments

Comments
 (0)