
Commit 96e51a9

Merge pull request #1069 from deeplearning4j/ag_fix_1063
Fix up generate text computation graph
2 parents 686db99 + 4838e3e

File tree: 1 file changed (+15 −11)

dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/generatetext/GenerateTxtCharCompGraphModel.java

+15 −11
@@ -55,7 +55,7 @@ public class GenerateTxtCharCompGraphModel {
 
     @SuppressWarnings("ConstantConditions")
     public static void main(String[] args) throws Exception {
-        int lstmLayerSize = 200;   //Number of units in each LSTM layer
+        int lstmLayerSize = 77;    //Number of units in each LSTM layer
         int miniBatchSize = 32;    //Size of mini batch to use when training
         int exampleLength = 1000;  //Length of each training example sequence to use. This could certainly be increased
         int tbpttLength = 50;      //Length for truncated backpropagation through time, i.e., do parameter updates every 50 characters
@@ -90,18 +90,20 @@ public static void main(String[] args) throws Exception {
             //Output layer, named "outputLayer", with inputs from the two layers called "first" and "second"
             .addLayer("outputLayer", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                 .activation(Activation.SOFTMAX)
-                .nIn(2*lstmLayerSize).nOut(nOut).build(), "first", "second")
+                .nIn(lstmLayerSize).nOut(lstmLayerSize).build(), "second")
             .setOutputs("outputLayer")  //List the output. For a ComputationGraph with multiple outputs, this also defines the output array orders
-            .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
+            .backpropType(BackpropType.TruncatedBPTT)
+            .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
             .build();
 
         ComputationGraph net = new ComputationGraph(conf);
         net.init();
         net.setListeners(new ScoreIterationListener(1));
+        System.out.println(net.summary());
 
         //Print the number of parameters in the network (and for each layer)
         long totalNumParams = 0;
-        for( int i=0; i<net.getNumLayers(); i++ ){
+        for( int i = 0; i < net.getNumLayers(); i++) {
             long nParams = net.getLayer(i).numParams();
             System.out.println("Number of parameters in layer " + i + ": " + nParams);
             totalNumParams += nParams;
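
Why the nIn change is consistent: in a DL4J ComputationGraph, a layer that lists several input vertices receives their activations merged (concatenated along the feature dimension), so feeding "outputLayer" from both "first" and "second" required nIn = 2*lstmLayerSize, while "second" alone needs only lstmLayerSize. Below is a minimal sketch of one plausible wiring after this change, assuming a DL4J version where LSTM.Builder is available; layer names mirror the example, but the nIn/nOut values and chain topology are illustrative, not taken from the repository:

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class CompGraphWiringSketch {
    public static void main(String[] args) {
        int nIn = 77, lstmLayerSize = 77, tbpttLength = 50;  //Illustrative sizes

        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("input")
            //Chain: input -> first -> second -> outputLayer
            .addLayer("first", new LSTM.Builder().nIn(nIn).nOut(lstmLayerSize)
                .activation(Activation.TANH).build(), "input")
            .addLayer("second", new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                .activation(Activation.TANH).build(), "first")
            //Single input vertex "second", so nIn equals its nOut
            .addLayer("outputLayer", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX)
                .nIn(lstmLayerSize).nOut(lstmLayerSize).build(), "second")
            .setOutputs("outputLayer")
            .backpropType(BackpropType.TruncatedBPTT)
            .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
            .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();
        System.out.println(net.summary());  //Same summary printout the commit adds after net.init()
    }
}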
@@ -110,16 +112,18 @@ public static void main(String[] args) throws Exception {
 
         //Do training, and then generate and print samples from network
         int miniBatchNumber = 0;
-        for( int i=0; i<numEpochs; i++ ){
+        for( int i = 0; i < numEpochs; i++) {
             while(iter.hasNext()){
                 DataSet ds = iter.next();
+                System.out.println("Input shape " + ds.getFeatures().shapeInfoToString());
+                System.out.println("Labels " + ds.getLabels().shapeInfoToString());
                 net.fit(ds);
                 if(++miniBatchNumber % generateSamplesEveryNMinibatches == 0){
                     System.out.println("--------------------");
                     System.out.println("Completed " + miniBatchNumber + " minibatches of size " + miniBatchSize + "x" + exampleLength + " characters");
                     System.out.println("Sampling characters from network given initialization \"" + (generationInitialization == null ? "" : generationInitialization) + "\"");
                     String[] samples = sampleCharactersFromNetwork(generationInitialization, net, iter, rng, nCharactersToSample, nSamplesToGenerate);
-                    for( int j=0; j<samples.length; j++ ){
+                    for( int j = 0; j < samples.length; j++) {
                         System.out.println("----- Sample " + j + " -----");
                         System.out.println(samples[j]);
                         System.out.println();
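
The two new println calls expose the rank-3 layout DL4J uses for recurrent-network data, which is useful when chasing nIn/nOut mismatches like the one fixed above. A small sketch of what they report, assuming the conventional [miniBatchSize, featureSize, timeSeriesLength] ordering; the sizes below are the example's defaults, and the exact shapeInfoToString text varies by ND4J version:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RnnShapeSketch {
    public static void main(String[] args) {
        //Character features for an RNN: [miniBatchSize, numCharacters, exampleLength]
        INDArray features = Nd4j.zeros(32, 77, 1000);
        //Same call as the debug lines added to the training loop
        System.out.println("Input shape " + features.shapeInfoToString());
        //Expect the shape [32,77,1000] to appear in the printed info
    }
}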
@@ -135,7 +139,7 @@ public static void main(String[] args) throws Exception {
 
     /** Generate a sample from the network, given an (optional, possibly null) initialization. Initialization
      * can be used to 'prime' the RNN with a sequence you want to extend/continue.<br>
-     * Note that the initalization is used for all samples
+     * Note that the initialization is used for all samples
      * @param initialization String, may be null. If null, select a random character as initialization for all samples
      * @param charactersToSample Number of characters to sample from network (excluding initialization)
      * @param net ComputationGraph with one or more LSTM/RNN layers and a softmax output layer
@@ -151,9 +155,9 @@ private static String[] sampleCharactersFromNetwork( String initialization, Comp
         //Create input for initialization
         INDArray initializationInput = Nd4j.zeros(numSamples, iter.inputColumns(), initialization.length());
         char[] init = initialization.toCharArray();
-        for( int i=0; i<init.length; i++ ){
+        for( int i = 0; i < init.length; i++) {
             int idx = iter.convertCharacterToIndex(init[i]);
-            for( int j=0; j<numSamples; j++ ){
+            for( int j = 0; j < numSamples; j++) {
                 initializationInput.putScalar(new int[]{j,idx,i}, 1.0f);
             }
         }
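
The loops above one-hot encode the priming string: character i of the initialization sets a single 1.0 at position [sample, characterIndex, timeStep]. A standalone sketch of the same encoding, with a toy vocabulary standing in for iter.convertCharacterToIndex (the names here are hypothetical, not from the repository):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OneHotPrimingSketch {
    public static void main(String[] args) {
        String vocab = "abc";          //Toy stand-in for the iterator's character set
        String initialization = "ab";
        int numSamples = 2;

        //[numSamples, vocabSize, initializationLength], all zeros to start
        INDArray input = Nd4j.zeros(numSamples, vocab.length(), initialization.length());
        char[] init = initialization.toCharArray();
        for (int i = 0; i < init.length; i++) {
            int idx = vocab.indexOf(init[i]);  //Stand-in for iter.convertCharacterToIndex
            for (int j = 0; j < numSamples; j++) {
                input.putScalar(new int[]{j, idx, i}, 1.0f);  //One-hot: sample j, char idx, step i
            }
        }
        System.out.println(input);
    }
}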
@@ -167,13 +171,13 @@ private static String[] sampleCharactersFromNetwork( String initialization, Comp
         INDArray output = net.rnnTimeStep(initializationInput)[0];
         output = output.tensorAlongDimension((int)output.size(2)-1, 1, 0);  //Gets the last time step output
 
-        for( int i=0; i<charactersToSample; i++ ){
+        for( int i = 0; i < charactersToSample; i++ ){
             //Set up next input (single time step) by sampling from previous output
             INDArray nextInput = Nd4j.zeros(numSamples, iter.inputColumns());
             //Output is a probability distribution. Sample from this for each example we want to generate, and add it to the new input
             for( int s=0; s<numSamples; s++ ){
                 double[] outputProbDistribution = new double[iter.totalOutcomes()];
-                for( int j=0; j<outputProbDistribution.length; j++ ) outputProbDistribution[j] = output.getDouble(s,j);
+                for( int j = 0; j < outputProbDistribution.length; j++) outputProbDistribution[j] = output.getDouble(s,j);
                 int sampledCharacterIdx = GenerateTxtModel.sampleFromDistribution(outputProbDistribution, rng);
 
                 nextInput.putScalar(new int[]{s,sampledCharacterIdx}, 1.0f);  //Prepare next time step input
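
GenerateTxtModel.sampleFromDistribution is called here but not shown in this diff. A plausible stand-in (hypothetical; not the repository's implementation) is inverse-CDF sampling over the softmax output, which returns index i with probability distribution[i]:

import java.util.Random;

public class SampleFromDistributionSketch {
    //Walk the cumulative distribution until it covers a uniform draw from [0,1)
    public static int sampleFromDistribution(double[] distribution, Random rng) {
        double d = rng.nextDouble();
        double cumulative = 0.0;
        for (int i = 0; i < distribution.length; i++) {
            cumulative += distribution[i];
            if (d <= cumulative) return i;
        }
        //Floating-point rounding can leave the sum slightly under 1.0; fall back to the last index
        return distribution.length - 1;
    }

    public static void main(String[] args) {
        double[] dist = {0.1, 0.6, 0.3};
        System.out.println(sampleFromDistribution(dist, new Random()));  //Index 1 is drawn ~60% of the time
    }
}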
