Skip to content

Commit e2716ea

Browse files
authored
Merge pull request #10 from ATISLabs/feature/adding_generate_s_curve_function
[#1] - Feature/adding generate s curve function
2 parents f9e6b25 + 57cc8ac commit e2716ea

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Dataset | Title
2929
----------------|------------------------------------------------------------------------|--------------------------------------------------
3030
make_blobs | Generate isotropic Gaussian blobs for clustering. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html)
3131
make_moons | Make two interleaving half circles. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html)
32-
32+
make_s_curve | Generate an S curve dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html)
3333

3434
[travis-img]: https://travis-ci.com/ATISLabs/SyntheticDatasets.jl.svg?branch=master
3535
[travis-url]: https://travis-ci.com/ATISLabs/SyntheticDatasets.jl

src/SyntheticDatasets.jl

+23-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,29 @@ function generate_blobs(;n_samples::Union{Int, Array{Int, 1}} = 100,
7575
return convert(features, labels)
7676
end
7777

78-
function convert(features::Array{T, 2}, labels::Array{Int, 1})::DataFrame where T <: Number
78+
"""
79+
generate_s_curve(; n_samples::Int = 100,
80+
noise = 0.0,
81+
random_state = nothing)::DataFrame
82+
Generate an S curve dataset. Sklearn interface to make_s_curve.
83+
# Arguments
84+
- `n_samples::Int = 100`: The number of sample points on the S curve.
85+
- `noise::Float64 = 0.0`: Standard deviation of Gaussian noise added to the data.
86+
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls.
87+
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html)
88+
"""
89+
function generate_s_curve(; n_samples::Int = 100,
90+
noise::Float64 = 0.0,
91+
random_state::Union{Int, Nothing} = nothing)::DataFrame
92+
93+
(features, labels) = datasets.make_s_curve( n_samples = n_samples,
94+
noise = noise,
95+
random_state = random_state)
96+
97+
return convert(features, labels)
98+
end
99+
100+
function convert(features::Array{T, 2}, labels::Array{D, 1})::DataFrame where {T <: Number, D <: Number}
79101
df = DataFrame()
80102

81103
for i = 1:size(features)[2]

test/runtests.jl

+8
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,12 @@ using Test
1717

1818
@test size(data)[1] == samples
1919
@test size(data)[2] == 3
20+
21+
data = SyntheticDatasets.generate_s_curve(n_samples = samples,
22+
noise = 2.2,
23+
random_state = 5)
24+
25+
@test size(data)[1] == samples
26+
@test size(data)[2] == 4
27+
2028
end

0 commit comments

Comments
 (0)