11/******************************************************************************* 
2-  * 
3-  * 
42 * 
53 * This program and the accompanying materials are made available under the 
64 * terms of the Apache License, Version 2.0 which is available at 
1715 * SPDX-License-Identifier: Apache-2.0 
1816 ******************************************************************************/ 
1917
20- package  org .deeplearning4j .examples .wip . advanced .modelling .melodl4j ;
18+ package  org .deeplearning4j .examples .advanced .modelling . charmodelling .melodl4j ;
2119
2220import  org .apache .commons .io .FileUtils ;
2321import  org .deeplearning4j .examples .advanced .modelling .charmodelling .utils .CharacterIterator ;
3129import  org .deeplearning4j .nn .weights .WeightInit ;
3230import  org .deeplearning4j .optimize .listeners .ScoreIterationListener ;
3331import  org .deeplearning4j .util .ModelSerializer ;
32+ import  org .nd4j .common .util .ArchiveUtils ;
3433import  org .nd4j .linalg .activations .Activation ;
3534import  org .nd4j .linalg .api .ndarray .INDArray ;
3635import  org .nd4j .linalg .dataset .DataSet ;
3736import  org .nd4j .linalg .factory .Nd4j ;
37+ import  org .nd4j .linalg .learning .config .Adam ;
3838import  org .nd4j .linalg .learning .config .RmsProp ;
3939import  org .nd4j .linalg .lossfunctions .LossFunctions ;
4040
41+ import  javax .sound .midi .InvalidMidiDataException ;
4142import  java .io .*;
4243import  java .net .URL ;
4344import  java .nio .charset .Charset ;
45+ import  java .nio .file .Files ;
46+ import  java .nio .file .Path ;
47+ import  java .text .NumberFormat ;
4448import  java .util .ArrayList ;
4549import  java .util .List ;
4650import  java .util .Random ;
51+ import  java .util .zip .ZipEntry ;
52+ import  java .util .zip .ZipInputStream ;
4753
4854/** 
4955 * LSTM  Symbolic melody modelling example, to compose music from symbolic melodies extracted from MIDI. 
50-  * Based  closely on LSTMCharModellingExample.java. 
56+  * LSTM logic is based  closely on LSTMCharModellingExample.java. 
5157 * See the README file in this directory for documentation. 
5258 * 
5359 * @author Alex Black, Donald A. Smith. 
5460 */ 
5561public  class  MelodyModelingExample  {
56-     final  static  String  inputSymbolicMelodiesFilename  = "bach-melodies-input.txt" ;
57-     // Examples:  bach-melodies-input.txt, beatles-melodies-input.txt ,  pop-melodies-input.txt (large) 
62+     // If you want to change the MIDI files used in learning, create a zip file containing your MIDI 
63+     // files and replace the following path.  For example, you might use something like: 
64+     //final static String midiFileZipFileUrlPath = "file:d:/music/midi/classical-midi.zip"; 
65+     final  static  String  midiFileZipFileUrlPath  = "http://waliberals.org/truthsite/music/bach-midi.zip" ;
5866
59-     final  static  String  tmpDir  = System .getProperty ("java.io.tmpdir" );
67+     // For example "bach-midi.txt" 
68+     final  static  String  inputSymbolicMelodiesFilename  = getMelodiesFileNameFromURLPath (midiFileZipFileUrlPath );
6069
61-     final  static  String  symbolicMelodiesInputFilePath  = tmpDir  + "/"  + inputSymbolicMelodiesFilename ;  // Point to melodies created by MidiMelodyExtractor.java 
70+     // Examples:  bach-melodies-input.txt, beatles-melodies-input.txt ,  pop-melodies-input.txt (large) 
71+     final  static  String  tmpDir  = System .getProperty ("java.io.tmpdir" );
72+     final  static  String  inputSymbolicMelodiesFilePath  = tmpDir  + "/"  + inputSymbolicMelodiesFilename ;  // Point to melodies created by MidiMelodyExtractor.java 
6273    final  static  String  composedMelodiesOutputFilePath  = tmpDir  + "/composition.txt" ; // You can listen to these melodies by running PlayMelodyStrings.java against this file. 
6374
6475    //final static String symbolicMelodiesInputFilePath = "D:/tmp/bach-melodies.txt"; 
6576    //final static String composedMelodiesOutputFilePath = tmpDir + "/bach-composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file. 
66- 
77+     final  static  NumberFormat  numberFormat  = NumberFormat .getNumberInstance ();
78+     static  {
79+         numberFormat .setMinimumFractionDigits (1 );
80+         numberFormat .setMaximumFractionDigits (1 );
81+     }
6782    //.... 
6883    public  static  void  main (String [] args ) throws  Exception  {
6984        String  loadNetworkPath  = null ; //"/tmp/MelodyModel-bach.zip"; //null; 
@@ -73,6 +88,8 @@ public static void main(String[] args) throws Exception {
7388            generationInitialization  = args [1 ];
7489        }
7590
91+         makeMidiStringFileIfNecessary ();
92+ 
7693        int  lstmLayerSize  = 200 ;                    //Number of units in each LSTM layer 
7794        int  miniBatchSize  = 32 ;                     //Size of mini batch to use when training 
7895        int  exampleLength  = 500 ; //1000; 		    //Length of each training example sequence to use. 
@@ -107,9 +124,10 @@ public static void main(String[] args) throws Exception {
107124
108125        //Set up network configuration: 
109126        MultiLayerConfiguration  conf  = new  NeuralNetConfiguration .Builder ()
110-             .updater (new  RmsProp (0.1 ))
111-             .seed (12345 )
112-             .l2 (0.001 )
127+             //.updater(new RmsProp(0.1)) 
128+             .updater (new  Adam (0.005 ))
129+             .seed (System .currentTimeMillis ()) // So each run generates new melodies 
130+             .l2 (0.0001 )
113131            .weightInit (WeightInit .XAVIER )
114132            .list ()
115133            .layer (0 , new  LSTM .Builder ().nIn (iter .inputColumns ()).nOut (lstmLayerSize )
@@ -123,7 +141,6 @@ public static void main(String[] args) throws Exception {
123141            .backpropType (BackpropType .TruncatedBPTT ).tBPTTForwardLength (tbpttLength ).tBPTTBackwardLength (tbpttLength )
124142            .build ();
125143
126- 
127144        learn (miniBatchSize , exampleLength , numEpochs , generateSamplesEveryNMinibatches , nSamplesToGenerate , nCharactersToSample , generationInitialization , rng , startTime , iter , conf );
128145    }
129146
@@ -154,6 +171,7 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
154171        // order, so that the best melodies are at the start of the file. 
155172        //Do training, and then generate and print samples from network 
156173        int  miniBatchNumber  = 0 ;
174+         long  lastTime  = System .currentTimeMillis ();
157175        for  (int  epoch  = 0 ; epoch  < numEpochs ; epoch ++) {
158176            System .out .println ("Starting epoch "  + epoch );
159177            while  (iter .hasNext ()) {
@@ -176,12 +194,19 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
176194                }
177195            }
178196            iter .reset ();    //Reset iterator for another epoch 
197+             final  double  secondsForEpoch  = 0.001  * (System .currentTimeMillis () - startTime );
198+             final  long  now  = System .currentTimeMillis ();
179199            if  (melodies .size () > 0 ) {
180200                String  melody  = melodies .get (melodies .size () - 1 );
181201                int  seconds  = 25 ;
182202                System .out .println ("\n First "  + seconds  + " seconds of "  + melody );
183203                PlayMelodyStrings .playMelody (melody , seconds );
184204            }
205+             double  seconds  = 0.001 *(now  - lastTime );
206+             lastTime  = now ;
207+             System .out .println ("\n Epoch "  + epoch  + " time in seconds: "  + numberFormat .format (seconds ));
208+             // 531.9 for GPU GTX 1070 
209+             // 821.4 for CPU i7-6700K @ 4GHZ 
185210        }
186211        int  indexOfLastPeriod  = inputSymbolicMelodiesFilename .lastIndexOf ('.' );
187212        String  saveFileName  = inputSymbolicMelodiesFilename .substring (0 , indexOfLastPeriod  > 0  ? indexOfLastPeriod  : inputSymbolicMelodiesFilename .length ());
@@ -193,42 +218,82 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
193218            printWriter .println (melodies .get (i ));
194219        }
195220        printWriter .close ();
196-         double  seconds  = 0.001  * (System .currentTimeMillis () - startTime );
197221
198-         System .out .println ("\n \n Example complete in "  + seconds  + " seconds" );
199222        System .exit (0 );
200223    }
201224
202-     public  static  void  makeSureFileIsInTmpDir (String  filename ) {
225+     public  static  File  makeSureFileIsInTmpDir (String  urlString ) throws  IOException  {
226+         final  URL  url  = new  URL (urlString );
227+         final  String  filename  = urlString .substring (1 +urlString .lastIndexOf ("/" ));
203228        final  File  f  = new  File (tmpDir  + "/"  + filename );
204-         if  (!f .exists ()) {
205-             URL  url  = null ;
206-             try  {
207-                 url  = new  URL ("http://truthsite.org/music/"  + filename );
208-                 FileUtils .copyURLToFile (url , f );
209-             } catch  (Exception  exc ) {
210-                 System .err .println ("Error copying "  + url  + " to "  + f );
211-                 throw  new  RuntimeException (exc );
212-             }
229+         if  (f .exists ()) {
230+             System .out .println ("Using existing "  + f .getAbsolutePath ());
231+         } else  {
232+             FileUtils .copyURLToFile (url , f );
213233            if  (!f .exists ()) {
214234                throw  new  RuntimeException (f .getAbsolutePath () + " does not exist" );
215235            }
216236            System .out .println ("File downloaded to "  + f .getAbsolutePath ());
217-         } else  {
218-             System .out .println ("Using existing text file at "  + f .getAbsolutePath ());
219237        }
238+         return  f ;
220239    }
221240
241+     //https://stackoverflow.com/questions/10633595/java-zip-how-to-unzip-folder 
242+     public  static  void  unzip (File  zipFile , File  targetDirFile ) throws  IOException  {
243+         InputStream  is  =  new  FileInputStream (zipFile );
244+         Path  targetDir  = targetDirFile .toPath ();
245+         targetDir  = targetDir .toAbsolutePath ();
246+         try  (ZipInputStream  zipIn  = new  ZipInputStream (is )) {
247+             for  (ZipEntry  ze ; (ze  = zipIn .getNextEntry ()) != null ; ) {
248+                 Path  resolvedPath  = targetDir .resolve (ze .getName ()).normalize ();
249+                 if  (!resolvedPath .startsWith (targetDir )) {
250+                     // see: https://snyk.io/research/zip-slip-vulnerability 
251+                     throw  new  RuntimeException ("Entry with an illegal path: " 
252+                             + ze .getName ());
253+                 }
254+                 if  (ze .isDirectory ()) {
255+                     Files .createDirectories (resolvedPath );
256+                 } else  {
257+                     Files .createDirectories (resolvedPath .getParent ());
258+                     Files .copy (zipIn , resolvedPath );
259+                 }
260+             }
261+         }
262+         is .close ();
263+     }
264+     private  static  void  makeMidiStringFileIfNecessary () throws  IOException , InvalidMidiDataException  {
265+         final  File  inputMelodiesFile  = new  File (inputSymbolicMelodiesFilePath );
266+         if  (inputMelodiesFile .exists () && inputMelodiesFile .length ()>1000 ) {
267+             System .out .println ("Using existing "  + inputSymbolicMelodiesFilePath );
268+             return ;
269+         }
270+         final  File  midiZipFile  = makeSureFileIsInTmpDir (midiFileZipFileUrlPath );
271+         final  String  midiZipFileName  = midiZipFile .getName ();
272+         final  String  midiZipFileNameWithoutSuffix  = midiZipFileName .substring (0 ,midiZipFileName .lastIndexOf ("." ));
273+         final  File  outputDirectoryFile   = new  File (tmpDir ,midiZipFileNameWithoutSuffix );
274+         final  String  outputDirectoryPath  = outputDirectoryFile .getAbsolutePath ();
275+         if  (!outputDirectoryFile .exists ()) {
276+             outputDirectoryFile .mkdir ();
277+         }
278+         if  (!outputDirectoryFile .exists () || !outputDirectoryFile .isDirectory ()) {
279+             throw  new  IllegalStateException (outputDirectoryFile  + " is not a directory or can't be created" );
280+         }
281+         final  PrintStream  printStream  = new  PrintStream (inputSymbolicMelodiesFilePath );
282+         System .out .println ("Unzipping " + midiZipFile .getAbsolutePath () + " to "  + outputDirectoryPath );
283+         unzip (midiZipFile , outputDirectoryFile );
284+         System .out .println ("Extracted "  + midiZipFile .getAbsolutePath () + " to "  + outputDirectoryPath );
285+         MidiMelodyExtractor .processDirectoryAndWriteMelodyFile (outputDirectoryFile ,inputMelodiesFile );
286+         printStream .close ();
287+     }
222288    /** 
223289     * Sets up and return a simple DataSetIterator that does vectorization based on the melody sample. 
224290     * 
225291     * @param miniBatchSize  Number of text segments in each training mini-batch 
226292     * @param sequenceLength Number of characters in each text segment. 
227293     */ 
228294    public  static  CharacterIterator  getMidiIterator (int  miniBatchSize , int  sequenceLength ) throws  Exception  {
229-         makeSureFileIsInTmpDir (inputSymbolicMelodiesFilename );
230295        final  char [] validCharacters  = MelodyStrings .allValidCharacters .toCharArray (); //Which characters are allowed? Others will be removed 
231-         return  new  CharacterIterator (symbolicMelodiesInputFilePath , Charset .forName ("UTF-8" ),
296+         return  new  CharacterIterator (inputSymbolicMelodiesFilePath , Charset .forName ("UTF-8" ),
232297            miniBatchSize , sequenceLength , validCharacters , new  Random (12345 ), MelodyStrings .COMMENT_STRING );
233298    }
234299
@@ -312,5 +377,13 @@ public static int sampleFromDistribution(double[] distribution, Random rng) {
312377        //Should be extremely unlikely to happen if distribution is a valid probability distribution 
313378        throw  new  IllegalArgumentException ("Distribution is invalid? d="  + d  + ", sum="  + sum );
314379    }
380+     private  static  String  getMelodiesFileNameFromURLPath (String  midiFileZipFileUrlPath ) {
381+         if  (!(midiFileZipFileUrlPath .endsWith (".zip" ) || midiFileZipFileUrlPath .endsWith (".ZIP" ))) {
382+             throw  new  IllegalStateException ("zipFilePath must end with .zip" );
383+         }
384+         midiFileZipFileUrlPath  = midiFileZipFileUrlPath .replace ('\\' ,'/' );
385+         String  fileName  = midiFileZipFileUrlPath .substring (midiFileZipFileUrlPath .lastIndexOf ("/" ) + 1 );
386+         return  fileName  + ".txt" ;
387+     }
315388}
316389
0 commit comments