Skip to content

Commit a876349

Browse files
Added generate_hastie_10_2
1 parent ecf7441 commit a876349

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

src/sklearn.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,22 @@ function generate_swiss_roll(; n_samples::Int = 100,
361361

362362
return convert(features, labels)
363363
end
364+
365+
366+
"""
367+
function generate_hastie_10_2(; n_samples::Int = 12000,
368+
random_state::Union{Int,Nothing} = nothing)
369+
Generates data for binary classification used in Hastie et al. 2009, Example 10.2.
370+
#Arguments
371+
- `n_samples::Int = 100`: The number of samples..
372+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
373+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_hastie_10_2.html)
374+
"""
375+
function generate_hastie_10_2(; n_samples::Int = 12000,
376+
random_state::Union{Int,Nothing} = nothing)
377+
378+
(features, labels) = datasets.make_hastie_10_2( n_samples = n_samples,
379+
random_state = random_state)
380+
381+
return convert(features, labels)
382+
end

test/runtests.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,18 @@ using Test
7171

7272
@test size(data)[1] == samples
7373
@test size(data)[2] == features
74-
75-
data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,
74+
75+
data = SyntheticDatasets.generate_swiss_roll(n_samples = samples,
7676
noise = 2.2,
7777
random_state = 5)
7878

7979
@test size(data)[1] == samples
8080
@test size(data)[2] == 4
8181
end
82+
83+
data = SyntheticDatasets.generate_hastie_10_2(n_samples = samples,
84+
random_state = 5)
85+
86+
@test size(data)[1] == samples
87+
@test size(data)[2] == 11
88+
end

0 commit comments

Comments
 (0)