Skip to content

Commit bb104c7

Browse files
committed
Clean up code; add better documentation for how to use other midi files; download zipped midi files and extract melody strings to a temporary file, if that file doesn't already exist
1 parent 96e51a9 commit bb104c7

File tree

7 files changed

+354
-203
lines changed

7 files changed

+354
-203
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/wip/advanced/modelling/melodl4j/MelodyModelingExample.java

Lines changed: 82 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,34 +31,45 @@
3131
import org.deeplearning4j.nn.weights.WeightInit;
3232
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
3333
import org.deeplearning4j.util.ModelSerializer;
34+
import org.nd4j.common.util.ArchiveUtils;
3435
import org.nd4j.linalg.activations.Activation;
3536
import org.nd4j.linalg.api.ndarray.INDArray;
3637
import org.nd4j.linalg.dataset.DataSet;
3738
import org.nd4j.linalg.factory.Nd4j;
3839
import org.nd4j.linalg.learning.config.RmsProp;
3940
import org.nd4j.linalg.lossfunctions.LossFunctions;
4041

42+
import javax.sound.midi.InvalidMidiDataException;
4143
import java.io.*;
4244
import java.net.URL;
4345
import java.nio.charset.Charset;
46+
import java.nio.file.Files;
47+
import java.nio.file.Path;
4448
import java.util.ArrayList;
4549
import java.util.List;
4650
import java.util.Random;
51+
import java.util.zip.ZipEntry;
52+
import java.util.zip.ZipInputStream;
4753

4854
/**
4955
* LSTM Symbolic melody modelling example, to compose music from symbolic melodies extracted from MIDI.
50-
* Based closely on LSTMCharModellingExample.java.
56+
* LSTM logic is based closely on LSTMCharModellingExample.java.
5157
* See the README file in this directory for documentation.
5258
*
5359
* @author Alex Black, Donald A. Smith.
5460
*/
5561
public class MelodyModelingExample {
56-
final static String inputSymbolicMelodiesFilename = "bach-melodies-input.txt";
57-
// Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
62+
// If you want to change the MIDI files used in learning, create a zip file containing your MIDI
63+
// files and replace the following path. For example, you might use something like:
64+
//final static String midiFileZipFileUrlPath = "file:d:/music/midi/classical-midi.zip";
65+
final static String midiFileZipFileUrlPath = "http://waliberals.org/truthsite/music/bach-midi.zip";
5866

59-
final static String tmpDir = System.getProperty("java.io.tmpdir");
67+
// For example "bach-midi.zip.txt" (the zip file's name with ".txt" appended)
68+
final static String inputSymbolicMelodiesFilename = getMelodiesFileNameFromURLPath(midiFileZipFileUrlPath);
6069

61-
final static String symbolicMelodiesInputFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename; // Point to melodies created by MidiMelodyExtractor.java
70+
// Examples: bach-melodies-input.txt, beatles-melodies-input.txt , pop-melodies-input.txt (large)
71+
final static String tmpDir = System.getProperty("java.io.tmpdir");
72+
final static String inputSymbolicMelodiesFilePath = tmpDir + "/" + inputSymbolicMelodiesFilename; // Point to melodies created by MidiMelodyExtractor.java
6273
final static String composedMelodiesOutputFilePath = tmpDir + "/composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
6374

6475
//final static String symbolicMelodiesInputFilePath = "D:/tmp/bach-melodies.txt";
@@ -73,6 +84,8 @@ public static void main(String[] args) throws Exception {
7384
generationInitialization = args[1];
7485
}
7586

87+
makeMidiStringFileIfNecessary();
88+
7689
int lstmLayerSize = 200; //Number of units in each LSTM layer
7790
int miniBatchSize = 32; //Size of mini batch to use when training
7891
int exampleLength = 500; //1000; //Length of each training example sequence to use.
@@ -199,36 +212,78 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
199212
System.exit(0);
200213
}
201214

202-
public static void makeSureFileIsInTmpDir(String filename) {
215+
public static File makeSureFileIsInTmpDir(String urlString) throws IOException {
216+
final URL url = new URL(urlString);
217+
final String filename = urlString.substring(1+urlString.lastIndexOf("/"));
203218
final File f = new File(tmpDir + "/" + filename);
204-
if (!f.exists()) {
205-
URL url = null;
206-
try {
207-
url = new URL("http://truthsite.org/music/" + filename);
208-
FileUtils.copyURLToFile(url, f);
209-
} catch (Exception exc) {
210-
System.err.println("Error copying " + url + " to " + f);
211-
throw new RuntimeException(exc);
212-
}
219+
if (f.exists()) {
220+
System.out.println("Using existing " + f.getAbsolutePath());
221+
} else {
222+
FileUtils.copyURLToFile(url, f);
213223
if (!f.exists()) {
214224
throw new RuntimeException(f.getAbsolutePath() + " does not exist");
215225
}
216226
System.out.println("File downloaded to " + f.getAbsolutePath());
217-
} else {
218-
System.out.println("Using existing text file at " + f.getAbsolutePath());
219227
}
228+
return f;
220229
}
221230

231+
//https://stackoverflow.com/questions/10633595/java-zip-how-to-unzip-folder
232+
public static void unzip(File zipFile, File targetDirFile) throws IOException {
233+
InputStream is = new FileInputStream(zipFile);
234+
Path targetDir = targetDirFile.toPath();
235+
targetDir = targetDir.toAbsolutePath();
236+
try (ZipInputStream zipIn = new ZipInputStream(is)) {
237+
for (ZipEntry ze; (ze = zipIn.getNextEntry()) != null; ) {
238+
Path resolvedPath = targetDir.resolve(ze.getName()).normalize();
239+
if (!resolvedPath.startsWith(targetDir)) {
240+
// see: https://snyk.io/research/zip-slip-vulnerability
241+
throw new RuntimeException("Entry with an illegal path: "
242+
+ ze.getName());
243+
}
244+
if (ze.isDirectory()) {
245+
Files.createDirectories(resolvedPath);
246+
} else {
247+
Files.createDirectories(resolvedPath.getParent());
248+
Files.copy(zipIn, resolvedPath);
249+
}
250+
}
251+
}
252+
is.close();
253+
}
254+
private static void makeMidiStringFileIfNecessary() throws IOException, InvalidMidiDataException {
255+
final File inputMelodiesFile = new File(inputSymbolicMelodiesFilePath);
256+
if (inputMelodiesFile.exists() && inputMelodiesFile.length()>1000) {
257+
System.out.println("Using existing " + inputSymbolicMelodiesFilePath);
258+
return;
259+
}
260+
final File midiZipFile = makeSureFileIsInTmpDir(midiFileZipFileUrlPath);
261+
final String midiZipFileName = midiZipFile.getName();
262+
final String midiZipFileNameWithoutSuffix = midiZipFileName.substring(0,midiZipFileName.lastIndexOf("."));
263+
final File outputDirectoryFile = new File(tmpDir,midiZipFileNameWithoutSuffix);
264+
final String outputDirectoryPath = outputDirectoryFile.getAbsolutePath();
265+
if (!outputDirectoryFile.exists()) {
266+
outputDirectoryFile.mkdir();
267+
}
268+
if (!outputDirectoryFile.exists() || !outputDirectoryFile.isDirectory()) {
269+
throw new IllegalStateException(outputDirectoryFile + " is not a directory or can't be created");
270+
}
271+
final PrintStream printStream = new PrintStream(inputSymbolicMelodiesFilePath);
272+
System.out.println("Unzipping "+ midiZipFile.getAbsolutePath() + " to " + outputDirectoryPath);
273+
unzip(midiZipFile, outputDirectoryFile);
274+
System.out.println("Extracted " + midiZipFile.getAbsolutePath() + " to " + outputDirectoryPath);
275+
MidiMelodyExtractor.processDirectoryAndWriteMelodyFile(outputDirectoryFile,inputMelodiesFile);
276+
printStream.close();
277+
}
222278
/**
223279
* Sets up and return a simple DataSetIterator that does vectorization based on the melody sample.
224280
*
225281
* @param miniBatchSize Number of text segments in each training mini-batch
226282
* @param sequenceLength Number of characters in each text segment.
227283
*/
228284
public static CharacterIterator getMidiIterator(int miniBatchSize, int sequenceLength) throws Exception {
229-
makeSureFileIsInTmpDir(inputSymbolicMelodiesFilename);
230285
final char[] validCharacters = MelodyStrings.allValidCharacters.toCharArray(); //Which characters are allowed? Others will be removed
231-
return new CharacterIterator(symbolicMelodiesInputFilePath, Charset.forName("UTF-8"),
286+
return new CharacterIterator(inputSymbolicMelodiesFilePath, Charset.forName("UTF-8"),
232287
miniBatchSize, sequenceLength, validCharacters, new Random(12345), MelodyStrings.COMMENT_STRING);
233288
}
234289

@@ -312,5 +367,13 @@ public static int sampleFromDistribution(double[] distribution, Random rng) {
312367
//Should be extremely unlikely to happen if distribution is a valid probability distribution
313368
throw new IllegalArgumentException("Distribution is invalid? d=" + d + ", sum=" + sum);
314369
}
370+
private static String getMelodiesFileNameFromURLPath(String midiFileZipFileUrlPath) {
371+
if (!(midiFileZipFileUrlPath.endsWith(".zip") || midiFileZipFileUrlPath.endsWith(".ZIP"))) {
372+
throw new IllegalStateException("zipFilePath must end with .zip");
373+
}
374+
midiFileZipFileUrlPath = midiFileZipFileUrlPath.replace('\\','/');
375+
String fileName = midiFileZipFileUrlPath.substring(midiFileZipFileUrlPath.lastIndexOf("/") + 1);
376+
return fileName + ".txt";
377+
}
315378
}
316379

dl4j-examples/src/main/java/org/deeplearning4j/examples/wip/advanced/modelling/melodl4j/MelodyStrings.java

Lines changed: 96 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,102 @@
2626
*
2727
* @author Don Smith
2828
*/
29+
2930
public class MelodyStrings {
3031
public static final String COMMENT_STRING = "//";
3132
// The following strings are used to build the symbolic representation of a melody
3233
// The next two strings contain chars used to indicate pitch deltas.
33-
public static final String noteGapCharsPositive = "0123456789abc"; // A pitch delta of "0" indicates delta=0.
34-
public static final String noteGapCharsNegative = "ABCDEFGHIJKLM"; // A pitch delta of "A" indicates delta=-1.
35-
// R is used to indicate the beginning of a rest
34+
// We use this ordering of characters so that the pitch gap order is the same as the ASCII order.
35+
public static final char lowestPitchGapChar = 'A';
36+
public static final char REST_CHAR = ' ';
37+
38+
// We allow pitch gaps between -12 and +12, inclusive.
39+
// If you want to change the allowed gap, you will have to change the characters in PITCH_GAP_CHARS_NEGATIVE and PITCH_GAP_CHARS_POSITIVE
40+
41+
public static int MAX_POSITIVE_PITCH_GAP = 12;
42+
// The order of characters below is intended to simplify the learning of pitch gaps, because increasing
43+
// characters correspond to increasing pitch gaps. The string must end with 'A'.
44+
public static final String PITCH_GAP_CHARS_NEGATIVE = "LKJIHGFEDCBA"; // "L" indicates delta=-1. "K" indicates -2,...
45+
46+
public static final char FIRST_PITCH_CHAR_NEGATIVE = PITCH_GAP_CHARS_NEGATIVE.charAt(0);
47+
48+
// There are thirteen chars in PITCH_GAP_CHARS_POSITIVE because the first one ('M') indicates a zero pitch gap.
49+
public static final String PITCH_GAP_CHARS_POSITIVE = "MNOPQRSTUVWXY"; // "M" indicates delta=0. "N" indicates 1, ...
50+
public static final char ZERO_PITCH_DELTA_CHAR = 'M';
51+
public static final char MAX_POSITIVE_PITCH_DELTA_CHAR = 'Y';
52+
53+
// durationDeltaParts determines how short durations can get. The shortest duration is 1/durationDeltaParts
54+
// as long as the average note length in the piece.
3655
public static int durationDeltaParts = 8;
37-
public static final String durationChars = "defghijkmnopqrstuvwzyz!@#$%^&*-_"; // 32 divisions. We omit lower-case L to avoid confusion with one.
38-
// 12345678901234567890123456789012
56+
public static final String DURATION_CHARS = "]^_`abcdefghijklmnopqrstuvwxyz{|"; // 32 divisions, in ASCII order
57+
58+
public static final char FIRST_DURATION_CHAR = ']';
3959
public static final String allValidCharacters = getValidCharacters();
4060
// 13+12+1+32 = 58 possible characters (13 non-negative pitch gaps, 12 negative pitch gaps, 1 rest, 32 durations).
41-
// 'd' indicates the smallest pitch duration allowed (typically a 1/32 note or so).
42-
// 'e' is a duration twice that of 'd'
43-
// 'f' is a duration three times that of 'd', etc.
44-
// If there is a rest between notes, we append 'R' followed by a char for the duration of the rest.
45-
public static final char restChar = 'R';
61+
// ']' is the first duration char: a duration that rounds to zero duration units.
62+
// '^' is one duration unit (the average note duration divided by durationDeltaParts).
63+
// '_' is two duration units, etc.
64+
// If there is a rest between notes, we append ' ' followed by a char for the duration of the rest.
4665

4766
/**
4867
* @return characters that may occur in a valid melody string
4968
*/
5069
private static String getValidCharacters() {
5170
StringBuilder sb = new StringBuilder();
52-
sb.append(noteGapCharsPositive);
53-
sb.append(noteGapCharsNegative);
54-
sb.append(durationChars);
55-
sb.append('R');
71+
sb.append(PITCH_GAP_CHARS_POSITIVE);
72+
sb.append(PITCH_GAP_CHARS_NEGATIVE);
73+
sb.append(DURATION_CHARS);
74+
sb.append(REST_CHAR);
5675
return sb.toString();
5776
}
77+
public static boolean isValidMelodyString(String string) {
78+
for(int i=0;i<string.length();i++) {
79+
if (i%2==0) {
80+
if (!isDurationChar(string.charAt(i))) {
81+
return false;
82+
}
83+
} else {
84+
if (!isPitchCharOrRest(string.charAt(i))) {
85+
return false;
86+
}
87+
}
88+
}
89+
return true;
90+
}
5891

92+
public static int getPitchDelta(final char ch) {
93+
if (ch >= ZERO_PITCH_DELTA_CHAR && ch <= MAX_POSITIVE_PITCH_DELTA_CHAR) {
94+
return ch - ZERO_PITCH_DELTA_CHAR;
95+
}
96+
if (ch >= 'A' && ch <= FIRST_PITCH_CHAR_NEGATIVE) {
97+
return - (1 + (FIRST_PITCH_CHAR_NEGATIVE - ch));
98+
}
99+
return 0;
100+
}
101+
public static char getCharForPitchGap(int pitchGap) {
102+
while (pitchGap>MAX_POSITIVE_PITCH_GAP) {
103+
pitchGap -= MAX_POSITIVE_PITCH_GAP;
104+
}
105+
while (pitchGap < -MAX_POSITIVE_PITCH_GAP) {
106+
pitchGap += MAX_POSITIVE_PITCH_GAP;
107+
}
108+
return (char) (ZERO_PITCH_DELTA_CHAR + pitchGap);
109+
}
110+
111+
public static int getDurationInTicks(char ch, int resolutionDelta) {
112+
int diff = Math.max(0,ch - FIRST_DURATION_CHAR);
113+
return diff * resolutionDelta;
114+
}
115+
116+
public static boolean isDurationChar(char ch) {
117+
return ch>=']' && ch <= '|';
118+
}
119+
public static boolean isPitchCharOrRest(char ch) {
120+
return ch == REST_CHAR || ch >= 'A' && ch <= 'Z';
121+
}
122+
public static boolean isPitchChar(char ch) {
123+
return ch >= 'A' && ch <= 'Z';
124+
}
59125
public static String convertToMelodyString(List<Note> noteSequence) {
60126
double averageNoteDuration = computeAverageDuration(noteSequence);
61127
double durationDelta = averageNoteDuration / durationDeltaParts;
@@ -70,15 +136,15 @@ public static String convertToMelodyString(List<Note> noteSequence) {
70136
long restDuration = note.getStartTick() - previousNote.getEndTick();
71137
if (restDuration > 0) {
72138
char restDurationChar = computeDurationChar(restDuration, durationDelta);
73-
sb.append(restChar);
139+
sb.append(REST_CHAR);
74140
sb.append(restDurationChar);
75141
}
76142
int pitchGap = note.getPitch() - previousNote.getPitch();
77-
while (pitchGap >= noteGapCharsPositive.length()) {
78-
pitchGap -= noteGapCharsPositive.length();
143+
while (pitchGap > MAX_POSITIVE_PITCH_GAP) {
144+
pitchGap -= MAX_POSITIVE_PITCH_GAP;
79145
}
80-
while (pitchGap < -noteGapCharsNegative.length()) {
81-
pitchGap += noteGapCharsNegative.length();
146+
while (pitchGap < -MAX_POSITIVE_PITCH_GAP) {
147+
pitchGap += MAX_POSITIVE_PITCH_GAP;
82148
}
83149
sb.append(getCharForPitchGap(pitchGap));
84150
long noteDuration = note.getDurationInTicks();
@@ -87,20 +153,25 @@ public static String convertToMelodyString(List<Note> noteSequence) {
87153
}
88154
previousNote = note;
89155
}
90-
return sb.toString();
156+
String result= sb.toString();
157+
if (!isValidMelodyString(result)) {
158+
System.err.println("Invalid melody string: " + result);
159+
}
160+
return result;
91161
}
92162

93-
private static char getCharForPitchGap(int pitchGap) {
94-
return pitchGap >= 0 ? noteGapCharsPositive.charAt(pitchGap) : noteGapCharsNegative.charAt(-1 - pitchGap);
95-
}
96163

97164
private static char computeDurationChar(long duration, double durationDelta) {
98-
int times = Math.min((int) Math.round(duration / durationDelta), durationChars.length() - 1);
165+
int times = Math.min((int) Math.round(duration / durationDelta), DURATION_CHARS.length() - 1);
99166
if (times < 0) {
100167
System.err.println("WARNING: Duration = " + duration);
101168
times = 0;
102169
}
103-
return durationChars.charAt(times);
170+
char ch = DURATION_CHARS.charAt(times);
171+
if (!isDurationChar(ch)) {
172+
throw new IllegalStateException("Invalid duration char " + ch + " for duration " + duration + ", " + durationDelta);
173+
}
174+
return ch;
104175
}
105176

106177
private static double computeAverageDuration(List<Note> noteSequence) {

0 commit comments

Comments
 (0)