Skip to content

Commit 22ae2c6

Browse files
committed
Merge branch 'master' into feature/generate_low_rank_matrix_2
2 parents ec70cec + 34ff75b commit 22ae2c6

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

src/sklearn.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,25 @@ function generate_low_rank_matrix(; n_samples::Int = 100,
268268
random_state = random_state)
269269
return features
270270
end
271+
272+
"""
273+
function generate_swiss_roll(; n_samples::Int = 100,
274+
noise::Float64 = 0.0,
275+
random_state::Union{Int,Nothing} = nothing)
276+
Generate a swiss roll dataset.
277+
#Arguments
278+
- `n_samples::Int = 100`: The number of samples.
279+
- `noise::Float64 = 0.0 : Standard deviation of Gaussian noise added to the data.
280+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
281+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_swiss_roll.htmll)
282+
"""
283+
function generate_swiss_roll(; n_samples::Int = 100,
284+
noise::Float64 = 0.0,
285+
random_state::Union{Int,Nothing} = nothing)
286+
287+
(features, labels) = datasets.make_swiss_roll( n_samples = n_samples,
288+
noise = noise,
289+
random_state = random_state)
290+
291+
return convert(features, labels)
292+
end

test/runtests.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,13 @@ using Test
5353
tail_strength = 0.5,
5454
random_state = 5)
5555

56-
5756
@test size(data)[1] == samples
5857
@test size(data)[2] == features
58+
59+
data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,
60+
noise = 2.2,
61+
random_state = 5)
62+
63+
@test size(data)[1] == samples
64+
@test size(data)[2] == 4
5965
end

0 commit comments

Comments
 (0)