1
1
/*******************************************************************************
2
- *
3
- *
4
2
*
5
3
* This program and the accompanying materials are made available under the
6
4
* terms of the Apache License, Version 2.0 which is available at
17
15
* SPDX-License-Identifier: Apache-2.0
18
16
******************************************************************************/
19
17
20
- package org .deeplearning4j .examples .wip . advanced .modelling .melodl4j ;
18
+ package org .deeplearning4j .examples .advanced .modelling . charmodelling .melodl4j ;
21
19
22
20
import org .apache .commons .io .FileUtils ;
23
21
import org .deeplearning4j .examples .advanced .modelling .charmodelling .utils .CharacterIterator ;
31
29
import org .deeplearning4j .nn .weights .WeightInit ;
32
30
import org .deeplearning4j .optimize .listeners .ScoreIterationListener ;
33
31
import org .deeplearning4j .util .ModelSerializer ;
32
+ import org .nd4j .common .util .ArchiveUtils ;
34
33
import org .nd4j .linalg .activations .Activation ;
35
34
import org .nd4j .linalg .api .ndarray .INDArray ;
36
35
import org .nd4j .linalg .dataset .DataSet ;
37
36
import org .nd4j .linalg .factory .Nd4j ;
37
+ import org .nd4j .linalg .learning .config .Adam ;
38
38
import org .nd4j .linalg .learning .config .RmsProp ;
39
39
import org .nd4j .linalg .lossfunctions .LossFunctions ;
40
40
41
+ import javax .sound .midi .InvalidMidiDataException ;
41
42
import java .io .*;
42
43
import java .net .URL ;
43
44
import java .nio .charset .Charset ;
45
+ import java .nio .file .Files ;
46
+ import java .nio .file .Path ;
47
+ import java .text .NumberFormat ;
44
48
import java .util .ArrayList ;
45
49
import java .util .List ;
46
50
import java .util .Random ;
51
+ import java .util .zip .ZipEntry ;
52
+ import java .util .zip .ZipInputStream ;
47
53
48
54
/**
49
55
* LSTM Symbolic melody modelling example, to compose music from symbolic melodies extracted from MIDI.
50
- * Based closely on LSTMCharModellingExample.java.
56
+ * LSTM logic is based closely on LSTMCharModellingExample.java.
51
57
* See the README file in this directory for documentation.
52
58
*
53
59
* @author Alex Black, Donald A. Smith.
54
60
*/
55
61
public class MelodyModelingExample {
56
- final static String inputSymbolicMelodiesFilename = "bach-melodies-input.txt" ;
57
- // Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
62
+ // If you want to change the MIDI files used in learning, create a zip file containing your MIDI
63
+ // files and replace the following path. For example, you might use something like:
64
+ //final static String midiFileZipFileUrlPath = "file:d:/music/midi/classical-midi.zip";
65
+ final static String midiFileZipFileUrlPath = "http://waliberals.org/truthsite/music/bach-midi.zip" ;
58
66
59
- final static String tmpDir = System .getProperty ("java.io.tmpdir" );
67
+ // For example "bach-midi.txt"
68
+ final static String inputSymbolicMelodiesFilename = getMelodiesFileNameFromURLPath (midiFileZipFileUrlPath );
60
69
61
- final static String symbolicMelodiesInputFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename ; // Point to melodies created by MidiMelodyExtractor.java
70
+ // Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
71
+ final static String tmpDir = System .getProperty ("java.io.tmpdir" );
72
+ final static String inputSymbolicMelodiesFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename ; // Point to melodies created by MidiMelodyExtractor.java
62
73
final static String composedMelodiesOutputFilePath = tmpDir + "/composition.txt" ; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
63
74
64
75
//final static String symbolicMelodiesInputFilePath = "D:/tmp/bach-melodies.txt";
65
76
//final static String composedMelodiesOutputFilePath = tmpDir + "/bach-composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
66
-
77
+ final static NumberFormat numberFormat = NumberFormat .getNumberInstance ();
78
+ static {
79
+ numberFormat .setMinimumFractionDigits (1 );
80
+ numberFormat .setMaximumFractionDigits (1 );
81
+ }
67
82
//....
68
83
public static void main (String [] args ) throws Exception {
69
84
String loadNetworkPath = null ; //"/tmp/MelodyModel-bach.zip"; //null;
@@ -73,6 +88,8 @@ public static void main(String[] args) throws Exception {
73
88
generationInitialization = args [1 ];
74
89
}
75
90
91
+ makeMidiStringFileIfNecessary ();
92
+
76
93
int lstmLayerSize = 200 ; //Number of units in each LSTM layer
77
94
int miniBatchSize = 32 ; //Size of mini batch to use when training
78
95
int exampleLength = 500 ; //1000; //Length of each training example sequence to use.
@@ -107,9 +124,10 @@ public static void main(String[] args) throws Exception {
107
124
108
125
//Set up network configuration:
109
126
MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
110
- .updater (new RmsProp (0.1 ))
111
- .seed (12345 )
112
- .l2 (0.001 )
127
+ //.updater(new RmsProp(0.1))
128
+ .updater (new Adam (0.005 ))
129
+ .seed (System .currentTimeMillis ()) // So each run generates new melodies
130
+ .l2 (0.0001 )
113
131
.weightInit (WeightInit .XAVIER )
114
132
.list ()
115
133
.layer (0 , new LSTM .Builder ().nIn (iter .inputColumns ()).nOut (lstmLayerSize )
@@ -123,7 +141,6 @@ public static void main(String[] args) throws Exception {
123
141
.backpropType (BackpropType .TruncatedBPTT ).tBPTTForwardLength (tbpttLength ).tBPTTBackwardLength (tbpttLength )
124
142
.build ();
125
143
126
-
127
144
learn (miniBatchSize , exampleLength , numEpochs , generateSamplesEveryNMinibatches , nSamplesToGenerate , nCharactersToSample , generationInitialization , rng , startTime , iter , conf );
128
145
}
129
146
@@ -154,6 +171,7 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
154
171
// order, so that the best melodies are at the start of the file.
155
172
//Do training, and then generate and print samples from network
156
173
int miniBatchNumber = 0 ;
174
+ long lastTime = System .currentTimeMillis ();
157
175
for (int epoch = 0 ; epoch < numEpochs ; epoch ++) {
158
176
System .out .println ("Starting epoch " + epoch );
159
177
while (iter .hasNext ()) {
@@ -176,12 +194,19 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
176
194
}
177
195
}
178
196
iter .reset (); //Reset iterator for another epoch
197
+ final double secondsForEpoch = 0.001 * (System .currentTimeMillis () - startTime );
198
+ final long now = System .currentTimeMillis ();
179
199
if (melodies .size () > 0 ) {
180
200
String melody = melodies .get (melodies .size () - 1 );
181
201
int seconds = 25 ;
182
202
System .out .println ("\n First " + seconds + " seconds of " + melody );
183
203
PlayMelodyStrings .playMelody (melody , seconds );
184
204
}
205
+ double seconds = 0.001 *(now - lastTime );
206
+ lastTime = now ;
207
+ System .out .println ("\n Epoch " + epoch + " time in seconds: " + numberFormat .format (seconds ));
208
+ // 531.9 for GPU GTX 1070
209
+ // 821.4 for CPU i7-6700K @ 4GHZ
185
210
}
186
211
int indexOfLastPeriod = inputSymbolicMelodiesFilename .lastIndexOf ('.' );
187
212
String saveFileName = inputSymbolicMelodiesFilename .substring (0 , indexOfLastPeriod > 0 ? indexOfLastPeriod : inputSymbolicMelodiesFilename .length ());
@@ -193,42 +218,82 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
193
218
printWriter .println (melodies .get (i ));
194
219
}
195
220
printWriter .close ();
196
- double seconds = 0.001 * (System .currentTimeMillis () - startTime );
197
221
198
- System .out .println ("\n \n Example complete in " + seconds + " seconds" );
199
222
System .exit (0 );
200
223
}
201
224
202
- public static void makeSureFileIsInTmpDir (String filename ) {
225
+ public static File makeSureFileIsInTmpDir (String urlString ) throws IOException {
226
+ final URL url = new URL (urlString );
227
+ final String filename = urlString .substring (1 +urlString .lastIndexOf ("/" ));
203
228
final File f = new File (tmpDir + "/" + filename );
204
- if (!f .exists ()) {
205
- URL url = null ;
206
- try {
207
- url = new URL ("http://truthsite.org/music/" + filename );
208
- FileUtils .copyURLToFile (url , f );
209
- } catch (Exception exc ) {
210
- System .err .println ("Error copying " + url + " to " + f );
211
- throw new RuntimeException (exc );
212
- }
229
+ if (f .exists ()) {
230
+ System .out .println ("Using existing " + f .getAbsolutePath ());
231
+ } else {
232
+ FileUtils .copyURLToFile (url , f );
213
233
if (!f .exists ()) {
214
234
throw new RuntimeException (f .getAbsolutePath () + " does not exist" );
215
235
}
216
236
System .out .println ("File downloaded to " + f .getAbsolutePath ());
217
- } else {
218
- System .out .println ("Using existing text file at " + f .getAbsolutePath ());
219
237
}
238
+ return f ;
220
239
}
221
240
241
+ //https://stackoverflow.com/questions/10633595/java-zip-how-to-unzip-folder
242
+ public static void unzip (File zipFile , File targetDirFile ) throws IOException {
243
+ InputStream is = new FileInputStream (zipFile );
244
+ Path targetDir = targetDirFile .toPath ();
245
+ targetDir = targetDir .toAbsolutePath ();
246
+ try (ZipInputStream zipIn = new ZipInputStream (is )) {
247
+ for (ZipEntry ze ; (ze = zipIn .getNextEntry ()) != null ; ) {
248
+ Path resolvedPath = targetDir .resolve (ze .getName ()).normalize ();
249
+ if (!resolvedPath .startsWith (targetDir )) {
250
+ // see: https://snyk.io/research/zip-slip-vulnerability
251
+ throw new RuntimeException ("Entry with an illegal path: "
252
+ + ze .getName ());
253
+ }
254
+ if (ze .isDirectory ()) {
255
+ Files .createDirectories (resolvedPath );
256
+ } else {
257
+ Files .createDirectories (resolvedPath .getParent ());
258
+ Files .copy (zipIn , resolvedPath );
259
+ }
260
+ }
261
+ }
262
+ is .close ();
263
+ }
264
+ private static void makeMidiStringFileIfNecessary () throws IOException , InvalidMidiDataException {
265
+ final File inputMelodiesFile = new File (inputSymbolicMelodiesFilePath );
266
+ if (inputMelodiesFile .exists () && inputMelodiesFile .length ()>1000 ) {
267
+ System .out .println ("Using existing " + inputSymbolicMelodiesFilePath );
268
+ return ;
269
+ }
270
+ final File midiZipFile = makeSureFileIsInTmpDir (midiFileZipFileUrlPath );
271
+ final String midiZipFileName = midiZipFile .getName ();
272
+ final String midiZipFileNameWithoutSuffix = midiZipFileName .substring (0 ,midiZipFileName .lastIndexOf ("." ));
273
+ final File outputDirectoryFile = new File (tmpDir ,midiZipFileNameWithoutSuffix );
274
+ final String outputDirectoryPath = outputDirectoryFile .getAbsolutePath ();
275
+ if (!outputDirectoryFile .exists ()) {
276
+ outputDirectoryFile .mkdir ();
277
+ }
278
+ if (!outputDirectoryFile .exists () || !outputDirectoryFile .isDirectory ()) {
279
+ throw new IllegalStateException (outputDirectoryFile + " is not a directory or can't be created" );
280
+ }
281
+ final PrintStream printStream = new PrintStream (inputSymbolicMelodiesFilePath );
282
+ System .out .println ("Unzipping " + midiZipFile .getAbsolutePath () + " to " + outputDirectoryPath );
283
+ unzip (midiZipFile , outputDirectoryFile );
284
+ System .out .println ("Extracted " + midiZipFile .getAbsolutePath () + " to " + outputDirectoryPath );
285
+ MidiMelodyExtractor .processDirectoryAndWriteMelodyFile (outputDirectoryFile ,inputMelodiesFile );
286
+ printStream .close ();
287
+ }
222
288
/**
223
289
* Sets up and return a simple DataSetIterator that does vectorization based on the melody sample.
224
290
*
225
291
* @param miniBatchSize Number of text segments in each training mini-batch
226
292
* @param sequenceLength Number of characters in each text segment.
227
293
*/
228
294
public static CharacterIterator getMidiIterator (int miniBatchSize , int sequenceLength ) throws Exception {
229
- makeSureFileIsInTmpDir (inputSymbolicMelodiesFilename );
230
295
final char [] validCharacters = MelodyStrings .allValidCharacters .toCharArray (); //Which characters are allowed? Others will be removed
231
- return new CharacterIterator (symbolicMelodiesInputFilePath , Charset .forName ("UTF-8" ),
296
+ return new CharacterIterator (inputSymbolicMelodiesFilePath , Charset .forName ("UTF-8" ),
232
297
miniBatchSize , sequenceLength , validCharacters , new Random (12345 ), MelodyStrings .COMMENT_STRING );
233
298
}
234
299
@@ -312,5 +377,13 @@ public static int sampleFromDistribution(double[] distribution, Random rng) {
312
377
//Should be extremely unlikely to happen if distribution is a valid probability distribution
313
378
throw new IllegalArgumentException ("Distribution is invalid? d=" + d + ", sum=" + sum );
314
379
}
380
+ private static String getMelodiesFileNameFromURLPath (String midiFileZipFileUrlPath ) {
381
+ if (!(midiFileZipFileUrlPath .endsWith (".zip" ) || midiFileZipFileUrlPath .endsWith (".ZIP" ))) {
382
+ throw new IllegalStateException ("zipFilePath must end with .zip" );
383
+ }
384
+ midiFileZipFileUrlPath = midiFileZipFileUrlPath .replace ('\\' ,'/' );
385
+ String fileName = midiFileZipFileUrlPath .substring (midiFileZipFileUrlPath .lastIndexOf ("/" ) + 1 );
386
+ return fileName + ".txt" ;
387
+ }
315
388
}
316
389
0 commit comments