Skip to content

Commit aa9696d

Browse files
authored
Merge pull request #13 from ATISLabs/feature/generate_classification_function
[#1] - Feature/generate classification function
2 parents e7f15d5 + 11006b4 commit aa9696d

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ make_blobs | Generate isotropic Gaussian blobs for clustering.
3131
make_moons | Make two interleaving half circles | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html)
3232
make_s_curve | Generate an S curve dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html)
3333
make_regression | Generate a random regression problem. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html])
34+
make_classification | Generate a random n-class classification problem. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html])
3435

3536
**Disclaimer**: SyntheticDatasets.jl borrows code and documentation from
3637
[scikit-learn](https://scikit-learn.org/stable/modules/classes.html#samples-generator) in the dataset module, but *it is not an official part

Diff for: src/sklearn.jl

+71
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,75 @@ function generate_regression(; n_samples::Int = 100,
140140

141141
return convert(features, labels)
142142

143+
end
144+
145+
"""
146+
function generate_classification(; n_samples::Int = 100,
147+
n_features::Int = 20,
148+
n_informative::Int = 2,
149+
n_redundant::Int = 2,
150+
n_repeated::Int = 0,
151+
n_classes::Int = 2,
152+
n_clusters_per_class::Int = 2,
153+
weights::Union{Nothing, Array{Float64,1}} = nothing,
154+
flip_y::Float64 = 0.01,
155+
class_sep::Float64 = 1.0,
156+
hypercube::Bool = true,
157+
shift::Union{Nothing, Array{Float64,1}} = 0.0,
158+
scale::Union{Nothing, Array{Float64,1}} = 1.0,
159+
shuffle::Bool = true,
160+
random_state::Union{Int, Nothing} = nothing)
161+
Generate a random n-class classification problem. Sklearn interface to make_classification.
162+
#Arguments
163+
- `n_samples::Int = 100`: The number of samples.
164+
- `n_features::Int = 20`: The total number of features. These comprise `n_informative` informative features, `n_redundant` redundant features, `n_repeated` duplicated features and `n_features-n_informative-n_redundant-n_repeated` useless features drawn at random.
165+
- `n_informative::Int = 2`: The number of informative features. Each class is composed of a number of gaussian clusters each located around the vertices of a hypercube in a subspace of dimension `n_informative`. For each cluster, informative features are drawn independently from N(0, 1) and then randomly linearly combined within each cluster in order to add covariance. The clusters are then placed on the vertices of the hypercube.
166+
- `n_redundant::Int = 2`: The number of redundant features. These features are generated as random linear combinations of the informative features.
167+
- `n_repeated::Int = 0`: The number of duplicated features, drawn randomly from the informative and the redundant features.
168+
- `n_classes::Int = 2`: The number of classes (or labels) of the classification problem.
169+
- `n_clusters_per_class::Int = 2`: The number of clusters per class.
170+
- `weights::Union{Nothing, Array{Float64,1}} = nothing`:
171+
- `flip_y::Float64 = 0.01`: The fraction of samples whose class is assigned randomly. Larger values introduce noise in the labels and make the classification task harder. Note that the default setting flip_y > 0 might lead to less than n_classes in y in some cases.
172+
- `class_sep::Float64 = 1.0`: The factor multiplying the hypercube size. Larger values spread out the clusters/classes and make the classification task easier.
173+
- `hypercube::Bool = true`: If True, the clusters are put on the vertices of a hypercube. If False, the clusters are put on the vertices of a random polytope.
174+
- `shift::Union{Nothing, Array{Float64,1}} = 0.0`: Shift features by the specified value. If None, then features are shifted by a random value drawn in [-class_sep, class_sep].
175+
- `scale::Union{Nothing, Array{Float64,1}} = 1.0`: Multiply features by the specified value. If None, then features are scaled by a random value drawn in [1, 100]. Note that scaling happens after shifting.
176+
- `shuffle::Bool = true`: Shuffle the samples and the features.
177+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
178+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html)
179+
"""
180+
function generate_classification(; n_samples::Int = 100,
181+
n_features::Int = 20,
182+
n_informative::Int = 2,
183+
n_redundant::Int = 2,
184+
n_repeated::Int = 0,
185+
n_classes::Int = 2,
186+
n_clusters_per_class::Int = 2,
187+
weights::Union{Nothing, Array{Float64,1}} = nothing,
188+
flip_y::Float64 = 0.01,
189+
class_sep::Float64 = 1.0,
190+
hypercube::Bool = true,
191+
shift::Union{Nothing, Float64, Array{Float64,1}} = 0.0,
192+
scale::Union{Nothing, Float64, Array{Float64,1}} = 1.0,
193+
shuffle::Bool = true,
194+
random_state::Union{Int, Nothing} = nothing)
195+
196+
197+
(features, labels) = datasets.make_classification( n_samples = n_samples,
198+
n_features = n_features,
199+
n_informative = n_informative,
200+
n_redundant = n_redundant,
201+
n_repeated = n_repeated,
202+
n_classes = n_classes,
203+
n_clusters_per_class = n_clusters_per_class,
204+
weights = weights,
205+
flip_y = flip_y,
206+
class_sep = class_sep,
207+
hypercube = hypercube,
208+
shift = shift,
209+
scale = scale,
210+
shuffle = shuffle,
211+
random_state = random_state)
212+
213+
return convert(features, labels)
143214
end

Diff for: test/runtests.jl

+8
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,12 @@ using Test
3636
@test size(data)[1] == samples
3737
@test size(data)[2] == features + 1
3838

39+
data = SyntheticDatasets.generate_classification(n_samples = samples,
40+
n_features = features,
41+
n_classes = 1)
42+
43+
44+
@test size(data)[1] == samples
45+
@test size(data)[2] == features + 1
46+
3947
end

0 commit comments

Comments
 (0)