Skip to content

Commit d0a5b20

Browse files
authored
Merge pull request #1070 from DonaldAlan/meldoy4j-review-971
Meldoy4j review 971
2 parents 96e51a9 + d7adac9 commit d0a5b20

File tree

9 files changed

+439
-251
lines changed

9 files changed

+439
-251
lines changed
+102-29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
/*******************************************************************************
2-
*
3-
*
42
*
53
* This program and the accompanying materials are made available under the
64
* terms of the Apache License, Version 2.0 which is available at
@@ -17,7 +15,7 @@
1715
* SPDX-License-Identifier: Apache-2.0
1816
******************************************************************************/
1917

20-
package org.deeplearning4j.examples.wip.advanced.modelling.melodl4j;
18+
package org.deeplearning4j.examples.advanced.modelling.charmodelling.melodl4j;
2119

2220
import org.apache.commons.io.FileUtils;
2321
import org.deeplearning4j.examples.advanced.modelling.charmodelling.utils.CharacterIterator;
@@ -31,39 +29,56 @@
3129
import org.deeplearning4j.nn.weights.WeightInit;
3230
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
3331
import org.deeplearning4j.util.ModelSerializer;
32+
import org.nd4j.common.util.ArchiveUtils;
3433
import org.nd4j.linalg.activations.Activation;
3534
import org.nd4j.linalg.api.ndarray.INDArray;
3635
import org.nd4j.linalg.dataset.DataSet;
3736
import org.nd4j.linalg.factory.Nd4j;
37+
import org.nd4j.linalg.learning.config.Adam;
3838
import org.nd4j.linalg.learning.config.RmsProp;
3939
import org.nd4j.linalg.lossfunctions.LossFunctions;
4040

41+
import javax.sound.midi.InvalidMidiDataException;
4142
import java.io.*;
4243
import java.net.URL;
4344
import java.nio.charset.Charset;
45+
import java.nio.file.Files;
46+
import java.nio.file.Path;
47+
import java.text.NumberFormat;
4448
import java.util.ArrayList;
4549
import java.util.List;
4650
import java.util.Random;
51+
import java.util.zip.ZipEntry;
52+
import java.util.zip.ZipInputStream;
4753

4854
/**
4955
* LSTM Symbolic melody modelling example, to compose music from symbolic melodies extracted from MIDI.
50-
* Based closely on LSTMCharModellingExample.java.
56+
* LSTM logic is based closely on LSTMCharModellingExample.java.
5157
* See the README file in this directory for documentation.
5258
*
5359
* @author Alex Black, Donald A. Smith.
5460
*/
5561
public class MelodyModelingExample {
56-
final static String inputSymbolicMelodiesFilename = "bach-melodies-input.txt";
57-
// Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
62+
// If you want to change the MIDI files used in learning, create a zip file containing your MIDI
63+
// files and replace the following path. For example, you might use something like:
64+
//final static String midiFileZipFileUrlPath = "file:d:/music/midi/classical-midi.zip";
65+
final static String midiFileZipFileUrlPath = "http://waliberals.org/truthsite/music/bach-midi.zip";
5866

59-
final static String tmpDir = System.getProperty("java.io.tmpdir");
67+
// For example "bach-midi.txt"
68+
final static String inputSymbolicMelodiesFilename = getMelodiesFileNameFromURLPath(midiFileZipFileUrlPath);
6069

61-
final static String symbolicMelodiesInputFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename; // Point to melodies created by MidiMelodyExtractor.java
70+
// Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
71+
final static String tmpDir = System.getProperty("java.io.tmpdir");
72+
final static String inputSymbolicMelodiesFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename; // Point to melodies created by MidiMelodyExtractor.java
6273
final static String composedMelodiesOutputFilePath = tmpDir + "/composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
6374

6475
//final static String symbolicMelodiesInputFilePath = "D:/tmp/bach-melodies.txt";
6576
//final static String composedMelodiesOutputFilePath = tmpDir + "/bach-composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
66-
77+
final static NumberFormat numberFormat = NumberFormat.getNumberInstance();
78+
static {
79+
numberFormat.setMinimumFractionDigits(1);
80+
numberFormat.setMaximumFractionDigits(1);
81+
}
6782
//....
6883
public static void main(String[] args) throws Exception {
6984
String loadNetworkPath = null; //"/tmp/MelodyModel-bach.zip"; //null;
@@ -73,6 +88,8 @@ public static void main(String[] args) throws Exception {
7388
generationInitialization = args[1];
7489
}
7590

91+
makeMidiStringFileIfNecessary();
92+
7693
int lstmLayerSize = 200; //Number of units in each LSTM layer
7794
int miniBatchSize = 32; //Size of mini batch to use when training
7895
int exampleLength = 500; //1000; //Length of each training example sequence to use.
@@ -107,9 +124,10 @@ public static void main(String[] args) throws Exception {
107124

108125
//Set up network configuration:
109126
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
110-
.updater(new RmsProp(0.1))
111-
.seed(12345)
112-
.l2(0.001)
127+
//.updater(new RmsProp(0.1))
128+
.updater(new Adam(0.005))
129+
.seed(System.currentTimeMillis()) // So each run generates new melodies
130+
.l2(0.0001)
113131
.weightInit(WeightInit.XAVIER)
114132
.list()
115133
.layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
@@ -123,7 +141,6 @@ public static void main(String[] args) throws Exception {
123141
.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
124142
.build();
125143

126-
127144
learn(miniBatchSize, exampleLength, numEpochs, generateSamplesEveryNMinibatches, nSamplesToGenerate, nCharactersToSample, generationInitialization, rng, startTime, iter, conf);
128145
}
129146

@@ -154,6 +171,7 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
154171
// order, so that the best melodies are at the start of the file.
155172
//Do training, and then generate and print samples from network
156173
int miniBatchNumber = 0;
174+
long lastTime = System.currentTimeMillis();
157175
for (int epoch = 0; epoch < numEpochs; epoch++) {
158176
System.out.println("Starting epoch " + epoch);
159177
while (iter.hasNext()) {
@@ -176,12 +194,19 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
176194
}
177195
}
178196
iter.reset(); //Reset iterator for another epoch
197+
final double secondsForEpoch = 0.001 * (System.currentTimeMillis() - startTime);
198+
final long now = System.currentTimeMillis();
179199
if (melodies.size() > 0) {
180200
String melody = melodies.get(melodies.size() - 1);
181201
int seconds = 25;
182202
System.out.println("\nFirst " + seconds + " seconds of " + melody);
183203
PlayMelodyStrings.playMelody(melody, seconds);
184204
}
205+
double seconds = 0.001*(now - lastTime);
206+
lastTime = now;
207+
System.out.println("\nEpoch " + epoch + " time in seconds: " + numberFormat.format(seconds));
208+
// 531.9 for GPU GTX 1070
209+
// 821.4 for CPU i7-6700K @ 4GHZ
185210
}
186211
int indexOfLastPeriod = inputSymbolicMelodiesFilename.lastIndexOf('.');
187212
String saveFileName = inputSymbolicMelodiesFilename.substring(0, indexOfLastPeriod > 0 ? indexOfLastPeriod : inputSymbolicMelodiesFilename.length());
@@ -193,42 +218,82 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
193218
printWriter.println(melodies.get(i));
194219
}
195220
printWriter.close();
196-
double seconds = 0.001 * (System.currentTimeMillis() - startTime);
197221

198-
System.out.println("\n\nExample complete in " + seconds + " seconds");
199222
System.exit(0);
200223
}
201224

202-
public static void makeSureFileIsInTmpDir(String filename) {
225+
public static File makeSureFileIsInTmpDir(String urlString) throws IOException {
226+
final URL url = new URL(urlString);
227+
final String filename = urlString.substring(1+urlString.lastIndexOf("/"));
203228
final File f = new File(tmpDir + "/" + filename);
204-
if (!f.exists()) {
205-
URL url = null;
206-
try {
207-
url = new URL("http://truthsite.org/music/" + filename);
208-
FileUtils.copyURLToFile(url, f);
209-
} catch (Exception exc) {
210-
System.err.println("Error copying " + url + " to " + f);
211-
throw new RuntimeException(exc);
212-
}
229+
if (f.exists()) {
230+
System.out.println("Using existing " + f.getAbsolutePath());
231+
} else {
232+
FileUtils.copyURLToFile(url, f);
213233
if (!f.exists()) {
214234
throw new RuntimeException(f.getAbsolutePath() + " does not exist");
215235
}
216236
System.out.println("File downloaded to " + f.getAbsolutePath());
217-
} else {
218-
System.out.println("Using existing text file at " + f.getAbsolutePath());
219237
}
238+
return f;
220239
}
221240

241+
//https://stackoverflow.com/questions/10633595/java-zip-how-to-unzip-folder
242+
public static void unzip(File zipFile, File targetDirFile) throws IOException {
243+
InputStream is = new FileInputStream(zipFile);
244+
Path targetDir = targetDirFile.toPath();
245+
targetDir = targetDir.toAbsolutePath();
246+
try (ZipInputStream zipIn = new ZipInputStream(is)) {
247+
for (ZipEntry ze; (ze = zipIn.getNextEntry()) != null; ) {
248+
Path resolvedPath = targetDir.resolve(ze.getName()).normalize();
249+
if (!resolvedPath.startsWith(targetDir)) {
250+
// see: https://snyk.io/research/zip-slip-vulnerability
251+
throw new RuntimeException("Entry with an illegal path: "
252+
+ ze.getName());
253+
}
254+
if (ze.isDirectory()) {
255+
Files.createDirectories(resolvedPath);
256+
} else {
257+
Files.createDirectories(resolvedPath.getParent());
258+
Files.copy(zipIn, resolvedPath);
259+
}
260+
}
261+
}
262+
is.close();
263+
}
264+
private static void makeMidiStringFileIfNecessary() throws IOException, InvalidMidiDataException {
265+
final File inputMelodiesFile = new File(inputSymbolicMelodiesFilePath);
266+
if (inputMelodiesFile.exists() && inputMelodiesFile.length()>1000) {
267+
System.out.println("Using existing " + inputSymbolicMelodiesFilePath);
268+
return;
269+
}
270+
final File midiZipFile = makeSureFileIsInTmpDir(midiFileZipFileUrlPath);
271+
final String midiZipFileName = midiZipFile.getName();
272+
final String midiZipFileNameWithoutSuffix = midiZipFileName.substring(0,midiZipFileName.lastIndexOf("."));
273+
final File outputDirectoryFile = new File(tmpDir,midiZipFileNameWithoutSuffix);
274+
final String outputDirectoryPath = outputDirectoryFile.getAbsolutePath();
275+
if (!outputDirectoryFile.exists()) {
276+
outputDirectoryFile.mkdir();
277+
}
278+
if (!outputDirectoryFile.exists() || !outputDirectoryFile.isDirectory()) {
279+
throw new IllegalStateException(outputDirectoryFile + " is not a directory or can't be created");
280+
}
281+
final PrintStream printStream = new PrintStream(inputSymbolicMelodiesFilePath);
282+
System.out.println("Unzipping "+ midiZipFile.getAbsolutePath() + " to " + outputDirectoryPath);
283+
unzip(midiZipFile, outputDirectoryFile);
284+
System.out.println("Extracted " + midiZipFile.getAbsolutePath() + " to " + outputDirectoryPath);
285+
MidiMelodyExtractor.processDirectoryAndWriteMelodyFile(outputDirectoryFile,inputMelodiesFile);
286+
printStream.close();
287+
}
222288
/**
223289
* Sets up and return a simple DataSetIterator that does vectorization based on the melody sample.
224290
*
225291
* @param miniBatchSize Number of text segments in each training mini-batch
226292
* @param sequenceLength Number of characters in each text segment.
227293
*/
228294
public static CharacterIterator getMidiIterator(int miniBatchSize, int sequenceLength) throws Exception {
229-
makeSureFileIsInTmpDir(inputSymbolicMelodiesFilename);
230295
final char[] validCharacters = MelodyStrings.allValidCharacters.toCharArray(); //Which characters are allowed? Others will be removed
231-
return new CharacterIterator(symbolicMelodiesInputFilePath, Charset.forName("UTF-8"),
296+
return new CharacterIterator(inputSymbolicMelodiesFilePath, Charset.forName("UTF-8"),
232297
miniBatchSize, sequenceLength, validCharacters, new Random(12345), MelodyStrings.COMMENT_STRING);
233298
}
234299

@@ -312,5 +377,13 @@ public static int sampleFromDistribution(double[] distribution, Random rng) {
312377
//Should be extremely unlikely to happen if distribution is a valid probability distribution
313378
throw new IllegalArgumentException("Distribution is invalid? d=" + d + ", sum=" + sum);
314379
}
380+
private static String getMelodiesFileNameFromURLPath(String midiFileZipFileUrlPath) {
381+
if (!(midiFileZipFileUrlPath.endsWith(".zip") || midiFileZipFileUrlPath.endsWith(".ZIP"))) {
382+
throw new IllegalStateException("zipFilePath must end with .zip");
383+
}
384+
midiFileZipFileUrlPath = midiFileZipFileUrlPath.replace('\\','/');
385+
String fileName = midiFileZipFileUrlPath.substring(midiFileZipFileUrlPath.lastIndexOf("/") + 1);
386+
return fileName + ".txt";
387+
}
315388
}
316389

0 commit comments

Comments
 (0)