Skip to content

Commit 5e40c76

Browse files
committed
Make LSTM config consisternt with GenerateTxtModel.java, clean up static final Strings in MelodyStrings.java so they're easier to modify
1 parent bb104c7 commit 5e40c76

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/wip/advanced/modelling/melodl4j/MelodyModelingExample.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.nd4j.linalg.api.ndarray.INDArray;
3737
import org.nd4j.linalg.dataset.DataSet;
3838
import org.nd4j.linalg.factory.Nd4j;
39+
import org.nd4j.linalg.learning.config.Adam;
3940
import org.nd4j.linalg.learning.config.RmsProp;
4041
import org.nd4j.linalg.lossfunctions.LossFunctions;
4142

@@ -120,9 +121,10 @@ public static void main(String[] args) throws Exception {
120121

121122
//Set up network configuration:
122123
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
123-
.updater(new RmsProp(0.1))
124-
.seed(12345)
125-
.l2(0.001)
124+
//.updater(new RmsProp(0.1))
125+
.updater(new Adam(0.005))
126+
.seed(System.currentTimeMillis()) // So each run generates new melodies
127+
.l2(0.0001)
126128
.weightInit(WeightInit.XAVIER)
127129
.list()
128130
.layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
@@ -136,7 +138,6 @@ public static void main(String[] args) throws Exception {
136138
.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
137139
.build();
138140

139-
140141
learn(miniBatchSize, exampleLength, numEpochs, generateSamplesEveryNMinibatches, nSamplesToGenerate, nCharactersToSample, generationInitialization, rng, startTime, iter, conf);
141142
}
142143

dl4j-examples/src/main/java/org/deeplearning4j/examples/wip/advanced/modelling/melodl4j/MelodyStrings.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,27 +35,32 @@ public class MelodyStrings {
3535
public static final char lowestPitchGapChar = 'A';
3636
public static final char REST_CHAR = ' ';
3737

38-
// We allow pitch gaps between -12 and +12, inclusive.
39-
// If you want to change the allowed gap, you will have to change the characters in PITCH_GAP_CHARS_NEGATIVE and PITCH_GAP_CHARS_POSITIVE
38+
// As written now, it allows pitch gaps between -12 and +12, inclusive.
39+
// If you want to change the allowed gap, you will have to change the characters in PITCH_GAP_CHARS_POSITIVE
40+
// and PITCH_GAP_CHARS_NEGATIVE
41+
42+
// There are thirteen chars in pitchGapCharsPositive because the first one ('M') indicates a zero pitch gap.
43+
// "M" indicates delta=0. "N" indicates delta=1, 'O' indicates delta=2, etc.
44+
public static final String PITCH_GAP_CHARS_POSITIVE = "MNOPQRSTUVWXY";
45+
46+
public static int MAX_POSITIVE_PITCH_GAP = PITCH_GAP_CHARS_POSITIVE.length()-1; // -1 because the first char is for zero.
47+
public static final char ZERO_PITCH_DELTA_CHAR = PITCH_GAP_CHARS_POSITIVE.charAt(0);
48+
public static final char MAX_POSITIVE_PITCH_DELTA_CHAR = PITCH_GAP_CHARS_POSITIVE.charAt(PITCH_GAP_CHARS_POSITIVE.length()-1);
4049

41-
public static int MAX_POSITIVE_PITCH_GAP = 12;
4250
// The order of characters below is intended to simplify the learning of pitch gaps, because increasing
43-
// characters correspond to increasing pitch gaps. The string must end with 'A'.
44-
public static final String PITCH_GAP_CHARS_NEGATIVE = "LKJIHGFEDCBA"; // "L" indicates delta=-1. "K" indicates -2,...
51+
// characters correspond to increasing pitch gaps. (Whether this convention simplifies learning pitches
52+
// depends on which learning algorithm is used.) The string must end with 'A'.
53+
// "L" indicates delta=-1. "K" indicates delta = -2, 'J' indicates delta = -3, etc.
54+
public static final String PITCH_GAP_CHARS_NEGATIVE = "LKJIHGFEDCBA";
4555

4656
public static final char FIRST_PITCH_CHAR_NEGATIVE = PITCH_GAP_CHARS_NEGATIVE.charAt(0);
4757

48-
// There are thirteen chars in pitchGapCharsPositive because the first one ('M') indicates a zero pitch gap.
49-
public static final String PITCH_GAP_CHARS_POSITIVE = "MNOPQRSTUVWXY"; // "M" indicates delta=0. "N" indicates 1, ...
50-
public static final char ZERO_PITCH_DELTA_CHAR = 'M';
51-
public static final char MAX_POSITIVE_PITCH_DELTA_CHAR = 'Y';
52-
5358
// durationDeltaParts determines how short durations can get. The shortest duration is 1/durationDeltaParts
5459
// as long as the average note length in the piece.
5560
public static int durationDeltaParts = 8;
5661
public static final String DURATION_CHARS = "]^_`abcdefghijklmnopqrstuvwxyz{|"; // 32 divisions, in ASCII order
5762

58-
public static final char FIRST_DURATION_CHAR = ']';
63+
public static final char FIRST_DURATION_CHAR = DURATION_CHARS.charAt(0);
5964
public static final String allValidCharacters = getValidCharacters();
6065
// 13+13+1+32 = 59 possible characters.
6166
// ']' indicates the smallest pitch duration allowed (typically a 1/32 note or so).

0 commit comments

Comments
 (0)