Skip to content

Commit f181872

Browse files
authored
Merge pull request #18 from ATISLabs/feature/generate_low_rank_matrix_2
Feature/Generate_Low_Rank_Matrix_2
2 parents 34ff75b + 22ae2c6 commit f181872

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

src/sklearn.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,35 @@ function generate_classification(; n_samples::Int = 100,
240240
return convert(features, labels)
241241
end
242242

243+
"""
244+
function generate_low_rank_matrix(; n_samples::Int =100,
245+
n_features::Int =100,
246+
effective_rank::Int =10,
247+
tail_strength::Float64 =0.5,
248+
random_state::Union{Int, Nothing} = nothing)
249+
Generate a mostly low rank matrix with bell-shaped singular values
250+
#Arguments
251+
- `n_samples::Int = 100`: The number of samples.
252+
- `n_features::Int = 20`: The total number of features. These comprise `n_informative` informative features, `n_redundant` redundant features, `n_repeated` duplicated features and `n_features-n_informative-n_redundant-n_repeated` useless features drawn at random.
253+
- `effective_rank::Int = 10`: The approximate number of singular vectors required to explain most of the data by linear combinations.
254+
- `tail_strength::Float64 = 0.5`: The relative importance of the fat noisy tail of the singular values profile.
255+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
256+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_low_rank_matrix.html)
257+
"""
258+
function generate_low_rank_matrix(; n_samples::Int = 100,
259+
n_features::Int = 100,
260+
effective_rank::Int = 10,
261+
tail_strength::Float64 = 0.5,
262+
random_state::Union{Int, Nothing} = nothing)
263+
264+
features = datasets.make_low_rank_matrix(n_samples = n_samples,
265+
n_features = n_features,
266+
effective_rank = effective_rank,
267+
tail_strength = tail_strength,
268+
random_state = random_state)
269+
return features
270+
end
271+
243272
"""
244273
function generate_swiss_roll(; n_samples::Int = 100,
245274
noise::Float64 = 0.0,

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ using Test
4747
@test size(data)[1] == samples
4848
@test size(data)[2] == features + 1
4949

50+
# generate_low_rank_matrix: the result must have `samples` entries along the
# first dimension and `features` along the second.
data = SyntheticDatasets.generate_low_rank_matrix(
    n_samples = samples,
    n_features = features,
    effective_rank = 10,
    tail_strength = 0.5,
    random_state = 5,
)

@test size(data, 1) == samples
@test size(data, 2) == features
58+
5059
data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,
5160
noise = 2.2,
5261
random_state = 5)

0 commit comments

Comments
 (0)