Skip to content

Commit 2388f1b

Browse files
authored
Merge pull request #27 from ATISLabs/feature/generate_gaussian_quantiles
[#1]feature/generate_gaussian_quantiles
2 parents e09eaa2 + 0fabb49 commit 2388f1b

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ make_regression | Generate a random regression problem.
4040
make_classification | Generate a random n-class classification problem. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html)
4141
make_low_rank_matrix | Generate a mostly low rank matrix with bell-shaped singular values. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_low_rank_matrix.html)
4242
make_swiss_roll | Generate a swiss roll dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_swiss_roll.html)
43+
make_gaussian_quantiles | Generate a swiss roll dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_gaussian_quantiles.html)
4344

4445
**Disclaimer**: SyntheticDatasets.jl borrows code and documentation from
4546
[scikit-learn](https://scikit-learn.org/stable/modules/classes.html#samples-generator) in the dataset module, but *it is not an official part

src/matlab.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ function generate_twospirals(; n_samples::Int = 2000,
3030
labels = [zeros(Int, N1); ones(Int, N1)]
3131

3232
return convert(features, labels);
33-
end
33+
end

src/sklearn.jl

+39
Original file line numberDiff line numberDiff line change
@@ -361,3 +361,42 @@ function generate_swiss_roll(; n_samples::Int = 100,
361361

362362
return convert(features, labels)
363363
end
364+
365+
"""
366+
function generate_gaussian_quantiles(; mean::Array{<:Union{Number, Nothing}, 1} = [nothing],
367+
cov::Float64 = 1,
368+
n_samples::Int = 100,
369+
n_features::Int = 2,
370+
n_classes::Int = 3,
371+
shuffle::Bool = true,
372+
random_state::Union{Int, Nothing} = nothing)
373+
374+
Generate isotropic Gaussian and label samples by quantile.
375+
#Arguments
376+
- `mean::Array{<:Union{Number, Nothing}, 1} = [nothing]`: The mean of the multi-dimensional normal distribution. If None then use the origin (0, 0, …).
377+
- `cov::Float64 = 1`: The covariance matrix will be this value times the unit matrix.
378+
- `n_samples::Int = 100`: The total number of points equally divided among classes.
379+
- `n_features::Int = 2`: The number of features for each sample.
380+
- `n_classes::Int = 3`: The number of classes.
381+
- `shuffle::Bool = true`: Shuffle the samples.
382+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
383+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_gaussian_quantiles.html)
384+
"""
385+
function generate_gaussian_quantiles(; mean::Union{Array{<:Number, 1}, Nothing} = nothing,
386+
cov::Float64 = 1.0,
387+
n_samples::Int = 100,
388+
n_features::Int = 2,
389+
n_classes::Int = 3,
390+
shuffle::Bool = true,
391+
random_state::Union{Int, Nothing} = nothing)
392+
393+
(features, labels) = datasets.make_gaussian_quantiles(mean = mean,
394+
cov = cov,
395+
n_samples = n_samples,
396+
n_features = n_features,
397+
n_classes = n_classes,
398+
shuffle = shuffle,
399+
random_state = random_state)
400+
401+
return convert(features, labels)
402+
end

test/runtests.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@ using Test
4747
@test size(data)[1] == samples
4848
@test size(data)[2] == features + 1
4949

50-
@test size(data)[1] == samples
51-
@test size(data)[2] == features + 1
52-
5350
data = SyntheticDatasets.generate_friedman1(n_samples = samples,
5451
n_features = features)
5552

@@ -74,13 +71,20 @@ using Test
7471

7572
@test size(data)[1] == samples
7673
@test size(data)[2] == features
77-
74+
7875
data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,
7976
noise = 2.2,
8077
random_state = 5)
8178

8279
@test size(data)[1] == samples
8380
@test size(data)[2] == 4
81+
82+
data = SyntheticDatasets.generate_gaussian_quantiles(n_samples = samples,
83+
n_features = features,
84+
random_state = 5)
85+
86+
@test size(data)[1] == samples
87+
@test size(data)[2] == features + 1
8488
end
8589

8690
@testset "Matlab Generators" begin

0 commit comments

Comments
 (0)