Skip to content

Commit 74e61a9

Browse files
Merge pull request #292 from SciML/fm/init_scaling
feat: increase specificity for scaled_rand init
2 parents 39ca9bf + 2de99aa commit 74e61a9

File tree

1 file changed

+104
-45
lines changed

1 file changed

+104
-45
lines changed

src/esn/esn_inits.jl

Lines changed: 104 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,15 @@ a range defined by `scaling`.
1717
# Keyword arguments
1818
1919
- `scaling`: A scaling factor to define the range of the uniform distribution.
20-
The matrix elements will be randomly chosen from the
21-
range `[-scaling, scaling]`. Defaults to `0.1`.
20+
The factor can be passed in three different ways:
21+
22+
+ A single number. In this case, the matrix elements will be randomly
23+
chosen from the range `[-scaling, scaling]`. Default option, with
24+
the scaling value set to `0.1`.
25+
+ A tuple `(lower, upper)`. The values define the range of the distribution.
26+
+ A vector. In this case, the columns will be scaled individually by the
27+
entries of the vector. The entries can be numbers or tuples, which will mirror
28+
the behavior described above.
2229
2330
# Examples
2431
@@ -33,16 +40,68 @@ julia> res_input = scaled_rand(8, 3)
3340
0.0944272 0.0679244 0.0148647
3441
-0.0799005 -0.0891089 -0.0444782
3542
-0.0970182 0.0934286 0.03553
43+
44+
julia> tt = scaled_rand(5, 3, scaling = (0.1, 0.15))
45+
5×3 Matrix{Float32}:
46+
0.13631 0.110929 0.116177
47+
0.116299 0.136038 0.119713
48+
0.11535 0.144712 0.110029
49+
0.127453 0.12657 0.147656
50+
0.139446 0.117656 0.104712
51+
```
52+
53+
Example with vector:
54+
55+
```jldoctest
56+
julia> tt = scaled_rand(5, 3, scaling = [0.1, 0.2, 0.3])
57+
5×3 Matrix{Float32}:
58+
0.0452399 -0.112565 -0.105874
59+
-0.0348047 0.0883044 -0.0634468
60+
-0.0386004 0.157698 -0.179648
61+
0.00981022 0.012559 0.271875
62+
0.0577838 -0.0587553 -0.243451
63+
64+
julia> tt = scaled_rand(5, 3, scaling = [(0.1, 0.2), (-0.2, -0.1), (0.3, 0.5)])
65+
5×3 Matrix{Float32}:
66+
0.17262 -0.178141 0.364709
67+
0.132598 -0.127924 0.378851
68+
0.1307 -0.110575 0.340117
69+
0.154905 -0.14686 0.490625
70+
0.178892 -0.164689 0.31885
3671
```
3772
"""
3873
function scaled_rand(rng::AbstractRNG, ::Type{T}, dims::Integer...;
        scaling::Union{Number, Tuple, Vector} = T(0.1)) where {T <: Number}
    # Draw uniform [0, 1) samples, then rescale in place according to `scaling`
    # (a single number, a (lower, upper) tuple, or a per-column vector; the
    # concrete behavior is selected by dispatch in `apply_scale!`).
    rows, cols = dims
    weights = DeviceAgnostic.rand(rng, T, rows, cols)
    return apply_scale!(weights, scaling, T)
end
4580

81+
function apply_scale!(input_matrix, scaling::Number, ::Type{T}) where {T}
    # Map uniform [0, 1) entries onto [-scaling, scaling) in place:
    # shift to be centered at zero, then stretch by 2 * scaling.
    stretch = T(2) * T(scaling)
    center = T(0.5)
    for pos in eachindex(input_matrix)
        input_matrix[pos] = (input_matrix[pos] - center) * stretch
    end
    return input_matrix
end
85+
86+
function apply_scale!(input_matrix,
        scaling::Tuple{<:Number, <:Number}, ::Type{T}) where {T}
    # Affinely map uniform [0, 1) entries onto [lower, upper) in place.
    lower, upper = T(scaling[1]), T(scaling[2])
    # Validate with a real exception: `@assert` is for internal invariants and
    # may be elided at higher optimization levels, silently skipping the check.
    lower < upper ||
        throw(ArgumentError("scaling bounds must satisfy lower < upper, got ($lower, $upper)"))
    scale = upper - lower
    @. input_matrix = input_matrix * scale + lower
    return input_matrix
end
94+
95+
function apply_scale!(input_matrix,
        scaling::AbstractVector, ::Type{T}) where {T <: Number}
    # Scale each column independently. Each entry of `scaling` may be a number
    # or a (lower, upper) tuple; dispatch picks the matching scalar/tuple method.
    ncols = size(input_matrix, 2)
    # Validate with a real exception: `@assert` is for internal invariants and
    # may be elided at higher optimization levels, silently skipping the check.
    length(scaling) == ncols ||
        throw(DimensionMismatch("expected $ncols column scalings, got $(length(scaling))"))
    for (idx, col) in enumerate(eachcol(input_matrix))
        apply_scale!(col, scaling[idx], T)
    end
    return input_matrix
end
104+
46105
"""
47106
weighted_init([rng], [T], dims...;
48107
scaling=0.1, return_sparse=false)
@@ -146,11 +205,11 @@ warning.
146205
```jldoctest
147206
julia> res_input = weighted_minimal(8, 3)
148207
┌ Warning: Reservoir size has changed!
149-
150-
│ Computed reservoir size (6) does not equal the provided reservoir size (8).
151-
152-
│ Using computed value (6). Make sure to modify the reservoir initializer accordingly.
153-
208+
209+
│ Computed reservoir size (6) does not equal the provided reservoir size (8).
210+
211+
│ Using computed value (6). Make sure to modify the reservoir initializer accordingly.
212+
154213
└ @ ReservoirComputing ~/.julia/dev/ReservoirComputing/src/esn/esn_inits.jl:159
155214
6×3 Matrix{Float32}:
156215
0.1 0.0 0.0
@@ -370,7 +429,7 @@ using a sine function and subsequent rows are iteratively generated
370429
via the Chebyshev mapping. The first row is defined as:
371430
372431
```math
373-
W[1, j] = \text{amplitude} \cdot \sin(j \cdot \pi / (\text{sine_divisor}
432+
W[1, j] = \text{amplitude} \cdot \sin(j \cdot \pi / (\text{sine_divisor}
374433
\cdot \text{n_cols}))
375434
```
376435
@@ -448,7 +507,7 @@ Generate an input weight matrix using a logistic mapping [Wang2022](@cite)
448507
The first row is initialized using a sine function:
449508
450509
```math
451-
W[1, j] = \text{amplitude} \cdot \sin(j \cdot \pi /
510+
W[1, j] = \text{amplitude} \cdot \sin(j \cdot \pi /
452511
(\text{sine_divisor} \cdot in_size))
453512
```
454513
@@ -527,7 +586,7 @@ as follows:
527586
- The first element of the chain is initialized using a sine function:
528587
529588
```math
530-
W[1,j] = \text{amplitude} \cdot \sin( (j \cdot \pi) /
589+
W[1,j] = \text{amplitude} \cdot \sin( (j \cdot \pi) /
531590
(\text{factor} \cdot \text{n} \cdot \text{sine_divisor}) )
532591
```
533592
where `j` is the index corresponding to the input and `n` is the number of inputs.
@@ -540,7 +599,7 @@ as follows:
540599
541600
The resulting matrix has dimensions `(factor * in_size) x in_size`, where
542601
`in_size` corresponds to the number of columns provided in `dims`.
543-
If the provided number of rows does not match `factor * in_size`
602+
If the provided number of rows does not match `factor * in_size`
544603
the number of rows is overridden.
545604
546605
# Arguments
@@ -576,15 +635,15 @@ julia> modified_lm(20, 10; factor=2)
576635
577636
julia> modified_lm(12, 4; factor=3)
578637
12×4 SparseArrays.SparseMatrixCSC{Float32, Int64} with 9 stored entries:
579-
⋅ ⋅ ⋅ ⋅
580-
⋅ ⋅ ⋅ ⋅
581-
⋅ ⋅ ⋅ ⋅
582-
⋅ 0.0133075 ⋅ ⋅
583-
⋅ 0.0308564 ⋅ ⋅
584-
⋅ 0.070275 ⋅ ⋅
585-
⋅ ⋅ 0.0265887 ⋅
586-
⋅ ⋅ 0.0608222 ⋅
587-
⋅ ⋅ 0.134239 ⋅
638+
⋅ ⋅ ⋅ ⋅
639+
⋅ ⋅ ⋅ ⋅
640+
⋅ ⋅ ⋅ ⋅
641+
⋅ 0.0133075 ⋅ ⋅
642+
⋅ 0.0308564 ⋅ ⋅
643+
⋅ 0.070275 ⋅ ⋅
644+
⋅ ⋅ 0.0265887 ⋅
645+
⋅ ⋅ 0.0608222 ⋅
646+
⋅ ⋅ 0.134239 ⋅
588647
⋅ ⋅ ⋅ 0.0398177
589648
⋅ ⋅ ⋅ 0.0898457
590649
⋅ ⋅ ⋅ 0.192168
@@ -671,7 +730,7 @@ function rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...;
671730
end
672731

673732
"""
674-
pseudo_svd([rng], [T], dims...;
733+
pseudo_svd([rng], [T], dims...;
675734
max_value=1.0, sparsity=0.1, sorted=true, reverse_sort=false,
676735
return_sparse=false)
677736
@@ -821,15 +880,15 @@ closest valid order is used.
821880
822881
```jldoctest
823882
julia> res_matrix = chaotic_init(8, 8)
824-
┌ Warning:
825-
883+
┌ Warning:
884+
826885
│ Adjusting reservoir matrix order:
827886
│ from 8 (requested) to 4
828-
│ based on computed bit precision = 1.
829-
887+
│ based on computed bit precision = 1.
888+
830889
└ @ ReservoirComputing ~/.julia/dev/ReservoirComputing/src/esn/esn_inits.jl:805
831890
4×4 SparseArrays.SparseMatrixCSC{Float32, Int64} with 6 stored entries:
832-
⋅ -0.600945 ⋅ ⋅
891+
⋅ -0.600945 ⋅ ⋅
833892
⋅ ⋅ 0.132667 2.21354
834893
⋅ -2.60383 ⋅ -2.90391
835894
-0.578156 ⋅ ⋅ ⋅
@@ -1148,7 +1207,7 @@ function delay_line_backward(rng::AbstractRNG, ::Type{T}, dims::Integer...;
11481207
end
11491208

11501209
"""
1151-
cycle_jumps([rng], [T], dims...;
1210+
cycle_jumps([rng], [T], dims...;
11521211
cycle_weight=0.1, jump_weight=0.1, jump_size=3, return_sparse=false,
11531212
cycle_kwargs=(), jump_kwargs=())
11541213
@@ -1234,7 +1293,7 @@ function cycle_jumps(rng::AbstractRNG, ::Type{T}, dims::Integer...;
12341293
end
12351294

12361295
"""
1237-
simple_cycle([rng], [T], dims...;
1296+
simple_cycle([rng], [T], dims...;
12381297
weight=0.1, return_sparse=false,
12391298
kwargs...)
12401299
@@ -1303,7 +1362,7 @@ function simple_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
13031362
end
13041363

13051364
"""
1306-
double_cycle([rng], [T], dims...;
1365+
double_cycle([rng], [T], dims...;
13071366
cycle_weight=0.1, second_cycle_weight=0.1,
13081367
return_sparse=false)
13091368
@@ -1358,7 +1417,7 @@ function double_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
13581417
end
13591418

13601419
"""
1361-
true_double_cycle([rng], [T], dims...;
1420+
true_double_cycle([rng], [T], dims...;
13621421
cycle_weight=0.1, second_cycle_weight=0.1,
13631422
return_sparse=false)
13641423
@@ -1427,7 +1486,7 @@ function true_double_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
14271486
end
14281487

14291488
@doc raw"""
1430-
selfloop_cycle([rng], [T], dims...;
1489+
selfloop_cycle([rng], [T], dims...;
14311490
cycle_weight=0.1, selfloop_weight=0.1,
14321491
return_sparse=false, kwargs...)
14331492
@@ -1518,7 +1577,7 @@ function selfloop_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
15181577
end
15191578

15201579
@doc raw"""
1521-
selfloop_feedback_cycle([rng], [T], dims...;
1580+
selfloop_feedback_cycle([rng], [T], dims...;
15221581
cycle_weight=0.1, selfloop_weight=0.1,
15231582
return_sparse=false)
15241583
@@ -1601,7 +1660,7 @@ function selfloop_feedback_cycle(rng::AbstractRNG, ::Type{T}, dims::Integer...;
16011660
end
16021661

16031662
@doc raw"""
1604-
selfloop_delayline_backward([rng], [T], dims...;
1663+
selfloop_delayline_backward([rng], [T], dims...;
16051664
weight=0.1, selfloop_weight=0.1, fb_weight=0.1,
16061665
fb_shift=2, return_sparse=false, fb_kwargs=(),
16071666
selfloop_kwargs=(), delay_kwargs=())
@@ -1707,7 +1766,7 @@ function selfloop_delayline_backward(rng::AbstractRNG, ::Type{T}, dims::Integer.
17071766
end
17081767

17091768
@doc raw"""
1710-
selfloop_forward_connection([rng], [T], dims...;
1769+
selfloop_forward_connection([rng], [T], dims...;
17111770
weight=0.1, selfloop_weight=0.1,
17121771
return_sparse=false, selfloop_kwargs=(),
17131772
delay_kwargs=())
@@ -1749,7 +1808,7 @@ W_{i,j} =
17491808
Default is 0.1.
17501809
- `return_sparse`: flag for returning a `sparse` matrix.
17511810
Default is `false`.
1752-
- `delay_kwargs` and `selfloop_kwargs`: named tuples that control the kwargs for the
1811+
- `delay_kwargs` and `selfloop_kwargs`: named tuples that control the kwargs for the
17531812
delay line weight and self loop weights respectively. The kwargs are as follows:
17541813
+ `sampling_type`: Sampling that decides the distribution of `weight` negative numbers.
17551814
If set to `:no_sample` the sign is unchanged. If set to `:bernoulli_sample!` then each
@@ -1801,7 +1860,7 @@ function selfloop_forward_connection(rng::AbstractRNG, ::Type{T}, dims::Integer.
18011860
end
18021861

18031862
@doc raw"""
1804-
forward_connection([rng], [T], dims...;
1863+
forward_connection([rng], [T], dims...;
18051864
weight=0.1, selfloop_weight=0.1,
18061865
return_sparse=false)
18071866
@@ -1887,8 +1946,8 @@ end
18871946
return_sparse=false)
18881947
18891948
Creates a block‐diagonal matrix consisting of square blocks of size
1890-
`block_size` along the main diagonal [Ma2023](@cite).
1891-
Each block may be filled with
1949+
`block_size` along the main diagonal [Ma2023](@cite).
1950+
Each block may be filled with
18921951
- a single scalar
18931952
- a vector of per‐block weights (length = number of blocks)
18941953
@@ -1897,21 +1956,21 @@ Each block may be filled with
18971956
```math
18981957
W_{i,j} =
18991958
\begin{cases}
1900-
w_b, & \text{if }\left\lfloor\frac{i-1}{s}\right\rfloor = \left\lfloor\frac{j-1}{s}\right\rfloor = b,\;
1959+
w_b, & \text{if }\left\lfloor\frac{i-1}{s}\right\rfloor = \left\lfloor\frac{j-1}{s}\right\rfloor = b,\;
19011960
s = \text{block\_size},\; b=0,\dots,nb-1, \\
19021961
0, & \text{otherwise,}
19031962
\end{cases}
19041963
```
19051964
19061965
# Arguments
19071966
1908-
- `rng`: Random number generator. Default is `Utils.default_rng()`.
1909-
- `T`: Element type of the matrix. Default is `Float32`.
1967+
- `rng`: Random number generator. Default is `Utils.default_rng()`.
1968+
- `T`: Element type of the matrix. Default is `Float32`.
19101969
- `dims`: Dimensions of the output matrix (must be two-dimensional).
19111970
19121971
# Keyword arguments
19131972
1914-
- `weight`:
1973+
- `weight`:
19151974
- scalar: every block is filled with that value
19161975
- vector: length = number of blocks, one constant per block
19171976
Default is `1.0`.

0 commit comments

Comments
 (0)