Skip to content

Commit d7adac9

Browse files
committed
Show time elapsed per epoch
Signed-off-by: Don A. Smith <[email protected]>
1 parent 556df5a commit d7adac9

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/melodl4j/MelodyModelingExample.java

+14-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import java.nio.charset.Charset;
4545
import java.nio.file.Files;
4646
import java.nio.file.Path;
47+
import java.text.NumberFormat;
4748
import java.util.ArrayList;
4849
import java.util.List;
4950
import java.util.Random;
@@ -73,7 +74,11 @@ public class MelodyModelingExample {
7374

7475
//final static String symbolicMelodiesInputFilePath = "D:/tmp/bach-melodies.txt";
7576
//final static String composedMelodiesOutputFilePath = tmpDir + "/bach-composition.txt"; // You can listen to these melodies by running PlayMelodyStrings.java against this file.
76-
77+
final static NumberFormat numberFormat = NumberFormat.getNumberInstance();
78+
static {
79+
numberFormat.setMinimumFractionDigits(1);
80+
numberFormat.setMaximumFractionDigits(1);
81+
}
7782
//....
7883
public static void main(String[] args) throws Exception {
7984
String loadNetworkPath = null; //"/tmp/MelodyModel-bach.zip"; //null;
@@ -166,6 +171,7 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
166171
// order, so that the best melodies are at the start of the file.
167172
//Do training, and then generate and print samples from network
168173
int miniBatchNumber = 0;
174+
long lastTime = System.currentTimeMillis();
169175
for (int epoch = 0; epoch < numEpochs; epoch++) {
170176
System.out.println("Starting epoch " + epoch);
171177
while (iter.hasNext()) {
@@ -188,12 +194,19 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
188194
}
189195
}
190196
iter.reset(); //Reset iterator for another epoch
197+
final double secondsForEpoch = 0.001 * (System.currentTimeMillis() - startTime);
198+
final long now = System.currentTimeMillis();
191199
if (melodies.size() > 0) {
192200
String melody = melodies.get(melodies.size() - 1);
193201
int seconds = 25;
194202
System.out.println("\nFirst " + seconds + " seconds of " + melody);
195203
PlayMelodyStrings.playMelody(melody, seconds);
196204
}
205+
double seconds = 0.001*(now - lastTime);
206+
lastTime = now;
207+
System.out.println("\nEpoch " + epoch + " time in seconds: " + numberFormat.format(seconds));
208+
// 531.9 for GPU GTX 1070
209+
// 821.4 for CPU i7-6700K @ 4GHZ
197210
}
198211
int indexOfLastPeriod = inputSymbolicMelodiesFilename.lastIndexOf('.');
199212
String saveFileName = inputSymbolicMelodiesFilename.substring(0, indexOfLastPeriod > 0 ? indexOfLastPeriod : inputSymbolicMelodiesFilename.length());
@@ -205,9 +218,7 @@ private static void learn(int miniBatchSize, int exampleLength, int numEpochs, i
205218
printWriter.println(melodies.get(i));
206219
}
207220
printWriter.close();
208-
double seconds = 0.001 * (System.currentTimeMillis() - startTime);
209221

210-
System.out.println("\n\nExample complete in " + seconds + " seconds");
211222
System.exit(0);
212223
}
213224

0 commit comments

Comments
 (0)