Skip to content

Commit 7bb11ac

Browse files
authored
Merge pull request #26 from ATISLabs/feature/generate_hastie_10_2
[#1] Feature/generate_hastie_10_2
2 parents 2388f1b + a51a6c0 commit 7bb11ac

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ make_regression | Generate a random regression problem.
4040
make_classification | Generate a random n-class classification problem. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html)
4141
make_low_rank_matrix | Generate a mostly low rank matrix with bell-shaped singular values. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_low_rank_matrix.html)
4242
make_swiss_roll | Generate a swiss roll dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_swiss_roll.html)
43+
make_hastie_10_2 | Generates data for binary classification used in Hastie et al. |[link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_hastie_10_2.html)
4344
make_gaussian_quantiles | Generate a swiss roll dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_gaussian_quantiles.html)
4445

4546
**Disclaimer**: SyntheticDatasets.jl borrows code and documentation from

Diff for: src/sklearn.jl

+19-1
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,24 @@ function generate_swiss_roll(; n_samples::Int = 100,
362362
return convert(features, labels)
363363
end
364364

365+
"""
366+
function generate_hastie_10_2(; n_samples::Int = 12000,
367+
random_state::Union{Int,Nothing} = nothing)
368+
Generates data for binary classification used in Hastie et al. 2009, Example 10.2.
369+
#Arguments
370+
- `n_samples::Int = 100`: The number of samples..
371+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
372+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_hastie_10_2.html)
373+
"""
374+
function generate_hastie_10_2(; n_samples::Int = 12000,
375+
random_state::Union{Int,Nothing} = nothing)
376+
377+
(features, labels) = datasets.make_hastie_10_2( n_samples = n_samples,
378+
random_state = random_state)
379+
380+
return convert(features, labels)
381+
end
382+
365383
"""
366384
function generate_gaussian_quantiles(; mean::Array{<:Union{Number, Nothing}, 1} = [nothing],
367385
cov::Float64 = 1,
@@ -398,5 +416,5 @@ function generate_gaussian_quantiles(; mean::Union{Array{<:Number, 1}, Nothing}
398416
shuffle = shuffle,
399417
random_state = random_state)
400418

401-
return convert(features, labels)
419+
return convert(features, labels)
402420
end

Diff for: test/runtests.jl

+18-12
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ using Test
4040
@test size(data)[1] == samples
4141
@test size(data)[2] == features + 1
4242

43-
data = SyntheticDatasets.generate_classification(n_samples = samples,
44-
n_features = features,
45-
n_classes = 1)
43+
data = SyntheticDatasets.generate_classification( n_samples = samples,
44+
n_features = features,
45+
n_classes = 1)
4646

4747
@test size(data)[1] == samples
4848
@test size(data)[2] == features + 1
4949

5050
data = SyntheticDatasets.generate_friedman1(n_samples = samples,
51-
n_features = features)
51+
n_features = features)
5252

5353
@test size(data)[1] == samples
5454
@test size(data)[2] == features + 1
@@ -63,26 +63,32 @@ using Test
6363
@test size(data)[1] == samples
6464
@test size(data)[2] == 5
6565

66-
data = SyntheticDatasets.generate_low_rank_matrix(n_samples = samples,
67-
n_features = features,
68-
effective_rank = 10,
69-
tail_strength = 0.5,
70-
random_state = 5)
66+
data = SyntheticDatasets.generate_low_rank_matrix( n_samples = samples,
67+
n_features = features,
68+
effective_rank = 10,
69+
tail_strength = 0.5,
70+
random_state = 5)
7171

7272
@test size(data)[1] == samples
7373
@test size(data)[2] == features
7474

75-
data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,
75+
data = SyntheticDatasets.generate_swiss_roll(n_samples = samples,
7676
noise = 2.2,
7777
random_state = 5)
7878

7979
@test size(data)[1] == samples
8080
@test size(data)[2] == 4
8181

82-
data = SyntheticDatasets.generate_gaussian_quantiles(n_samples = samples,
83-
n_features = features,
82+
data = SyntheticDatasets.generate_hastie_10_2( n_samples = samples,
8483
random_state = 5)
8584

85+
@test size(data)[1] == samples
86+
@test size(data)[2] == 11
87+
88+
data = SyntheticDatasets.generate_gaussian_quantiles( n_samples = samples,
89+
n_features = features,
90+
random_state = 5)
91+
8692
@test size(data)[1] == samples
8793
@test size(data)[2] == features + 1
8894
end

0 commit comments

Comments
 (0)