Skip to content

Commit e7f15d5

Browse files
authored
Merge pull request #11 from ATISLabs/feature/generate_regression_function
[#1] - Feature/generate regression function
2 parents 43d637b + 0eaaf7d commit e7f15d5

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Dataset | Title
3030
make_blobs | Generate isotropic Gaussian blobs for clustering. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html)
3131
make_moons | Make two interleaving half circles | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html)
3232
make_s_curve | Generate an S curve dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html)
33+
make_regression | Generate a random regression problem. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html])
3334

3435
**Disclaimer**: SyntheticDatasets.jl borrows code and documentation from
3536
[scikit-learn](https://scikit-learn.org/stable/modules/classes.html#samples-generator) in the dataset module, but *it is not an official part

src/sklearn.jl

+57
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,61 @@ function generate_s_curve(; n_samples::Int = 100,
8383
random_state = random_state)
8484

8585
return convert(features, labels)
86+
end
87+
88+
"""
89+
generate_regression(; n_samples::Int = 100,
90+
n_features::Int = 100,
91+
n_informative::Int = 10,
92+
n_targets::Int = 1,
93+
bias::Float64 = 0.0,
94+
effective_rank::Union{Int, Nothing} = nothing,
95+
tail_strength::Float64 = 0.5,
96+
noise::Float64 = 0.0,
97+
shuffle::Bool = true,
98+
coef::Bool = false,
99+
random_state::Union{Int, Nothing}= nothing)
100+
Generate a random regression problem. Sklearn interface to make_regression.
101+
# Arguments
102+
- `n_samples::Int = 100`: The number of samples.
103+
- `n_features::Int = 2`: The number of features.
104+
- `n_informative::Int = 10`: The number of informative features, i.e., the number of features used to build the linear model used to generate the output.
105+
- `n_targets::Int = 1`: The number of regression targets, i.e., the dimension of the y output vector associated with a sample. By default, the output is a scalar.
106+
- `bias::Float = 0.0`: The bias term in the underlying linear model.
107+
- `effective_rank::Union{Int, Nothing} = nothing`: If not `nothing`, the approximate number of singular vectors required to explain most of the input data by linear combinations. Using this kind of singular spectrum in the input allows the generator to reproduce the correlations often observed in practice. If `nothing`, the input set is well conditioned, centered and gaussian with unit variance.
108+
- `tail_strength::Float = 0.5`: The relative importance of the fat noisy tail of the singular values profile if effective_rank is not None.
109+
- `noise::Union{Nothing, Float64} = nothing`: Standard deviation of Gaussian noise added to the data.
110+
- `shuffle::Bool = true`: Shuffle the samples and the features.
111+
- `coef::Bool = false`: If `true`, the coefficients of the underlying linear model are returned.
112+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset shuffling and noise.
113+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html)
114+
"""
115+
function generate_regression(; n_samples::Int = 100,
116+
n_features::Int = 100,
117+
n_informative::Int = 10,
118+
n_targets::Int = 1,
119+
bias::Float64 = 0.0,
120+
effective_rank::Union{Int, Nothing} = nothing,
121+
tail_strength::Float64 = 0.5,
122+
noise::Float64 = 0.0,
123+
shuffle::Bool = true,
124+
coef::Bool = false,
125+
random_state::Union{Int, Nothing}= nothing)
126+
127+
128+
(features, labels) = datasets.make_regression( n_samples = n_samples,
129+
n_features = n_features,
130+
n_informative = n_informative,
131+
n_targets = n_targets,
132+
bias = bias,
133+
effective_rank = effective_rank,
134+
tail_strength = tail_strength,
135+
noise = noise,
136+
shuffle = shuffle,
137+
coef = coef,
138+
random_state = random_state)
139+
140+
141+
return convert(features, labels)
142+
86143
end

test/runtests.jl

+11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using Test
44

55
@testset "SkLearn Generators" begin
66
samples = 20000
7+
features = 20
8+
79
data = SyntheticDatasets.generate_blobs(centers = [-1 1;-0.5 0.75],
810
cluster_std = 0.225,
911
n_samples = 20000,
@@ -25,4 +27,13 @@ using Test
2527
@test size(data)[1] == samples
2628
@test size(data)[2] == 4
2729

30+
31+
data = SyntheticDatasets.generate_regression(n_samples = samples,
32+
n_features = features,
33+
noise = 2.2,
34+
random_state = 5)
35+
36+
@test size(data)[1] == samples
37+
@test size(data)[2] == features + 1
38+
2839
end

0 commit comments

Comments
 (0)