@@ -55,7 +55,7 @@ public class GenerateTxtCharCompGraphModel {
55
55
56
56
@ SuppressWarnings ("ConstantConditions" )
57
57
public static void main (String [] args ) throws Exception {
58
- int lstmLayerSize = 200 ; //Number of units in each LSTM layer
58
+ int lstmLayerSize = 77 ; //Number of units in each LSTM layer
59
59
int miniBatchSize = 32 ; //Size of mini batch to use when training
60
60
int exampleLength = 1000 ; //Length of each training example sequence to use. This could certainly be increased
61
61
int tbpttLength = 50 ; //Length for truncated backpropagation through time. i.e., do parameter updates every 50 characters
@@ -90,18 +90,20 @@ public static void main(String[] args ) throws Exception {
90
90
//Output layer, name "outputLayer" with inputs from the two layers called "first" and "second"
91
91
.addLayer ("outputLayer" , new RnnOutputLayer .Builder (LossFunctions .LossFunction .MCXENT )
92
92
.activation (Activation .SOFTMAX )
93
- .nIn (2 * lstmLayerSize ).nOut (nOut ).build (), "first" ,"second" )
93
+ .nIn (lstmLayerSize ).nOut (lstmLayerSize ).build (),"second" )
94
94
.setOutputs ("outputLayer" ) //List the output. For a ComputationGraph with multiple outputs, this also defines the input array orders
95
- .backpropType (BackpropType .TruncatedBPTT ).tBPTTForwardLength (tbpttLength ).tBPTTBackwardLength (tbpttLength )
95
+ .backpropType (BackpropType .TruncatedBPTT )
96
+ .tBPTTForwardLength (tbpttLength ).tBPTTBackwardLength (tbpttLength )
96
97
.build ();
97
98
98
99
ComputationGraph net = new ComputationGraph (conf );
99
100
net .init ();
100
101
net .setListeners (new ScoreIterationListener (1 ));
102
+ System .out .println (net .summary ());
101
103
102
104
//Print the number of parameters in the network (and for each layer)
103
105
long totalNumParams = 0 ;
104
- for ( int i = 0 ; i < net .getNumLayers (); i ++ ) {
106
+ for ( int i = 0 ; i < net .getNumLayers (); i ++) {
105
107
long nParams = net .getLayer (i ).numParams ();
106
108
System .out .println ("Number of parameters in layer " + i + ": " + nParams );
107
109
totalNumParams += nParams ;
@@ -110,16 +112,18 @@ public static void main(String[] args ) throws Exception {
110
112
111
113
//Do training, and then generate and print samples from network
112
114
int miniBatchNumber = 0 ;
113
- for ( int i = 0 ; i < numEpochs ; i ++ ) {
115
+ for ( int i = 0 ; i < numEpochs ; i ++) {
114
116
while (iter .hasNext ()){
115
117
DataSet ds = iter .next ();
118
+ System .out .println ("Input shape " + ds .getFeatures ().shapeInfoToString ());
119
+ System .out .println ("Labels " + ds .getLabels ().shapeInfoToString ());
116
120
net .fit (ds );
117
121
if (++miniBatchNumber % generateSamplesEveryNMinibatches == 0 ){
118
122
System .out .println ("--------------------" );
119
123
System .out .println ("Completed " + miniBatchNumber + " minibatches of size " + miniBatchSize + "x" + exampleLength + " characters" );
120
124
System .out .println ("Sampling characters from network given initialization \" " + (generationInitialization == null ? "" : generationInitialization ) + "\" " );
121
125
String [] samples = sampleCharactersFromNetwork (generationInitialization ,net ,iter ,rng ,nCharactersToSample ,nSamplesToGenerate );
122
- for ( int j = 0 ; j < samples .length ; j ++ ) {
126
+ for ( int j = 0 ; j < samples .length ; j ++) {
123
127
System .out .println ("----- Sample " + j + " -----" );
124
128
System .out .println (samples [j ]);
125
129
System .out .println ();
@@ -135,7 +139,7 @@ public static void main(String[] args ) throws Exception {
135
139
136
140
/** Generate a sample from the network, given an (optional, possibly null) initialization. Initialization
137
141
* can be used to 'prime' the RNN with a sequence you want to extend/continue.<br>
138
- * Note that the initalization is used for all samples
142
+ * Note that the initialization is used for all samples
139
143
* @param initialization String, may be null. If null, select a random character as initialization for all samples
140
144
* @param charactersToSample Number of characters to sample from network (excluding initialization)
141
145
* @param net ComputationGraph with one or more LSTM/RNN layers and a softmax output layer
@@ -151,9 +155,9 @@ private static String[] sampleCharactersFromNetwork( String initialization, Comp
151
155
//Create input for initialization
152
156
INDArray initializationInput = Nd4j .zeros (numSamples , iter .inputColumns (), initialization .length ());
153
157
char [] init = initialization .toCharArray ();
154
- for ( int i =0 ; i <init .length ; i ++ ) {
158
+ for ( int i =0 ; i <init .length ; i ++) {
155
159
int idx = iter .convertCharacterToIndex (init [i ]);
156
- for ( int j = 0 ; j <numSamples ; j ++ ){
160
+ for ( int j = 0 ; j <numSamples ; j ++ ){
157
161
initializationInput .putScalar (new int []{j ,idx ,i }, 1.0f );
158
162
}
159
163
}
@@ -167,13 +171,13 @@ private static String[] sampleCharactersFromNetwork( String initialization, Comp
167
171
INDArray output = net .rnnTimeStep (initializationInput )[0 ];
168
172
output = output .tensorAlongDimension ((int )output .size (2 )-1 ,1 ,0 ); //Gets the last time step output
169
173
170
- for ( int i = 0 ; i < charactersToSample ; i ++ ){
174
+ for ( int i = 0 ; i < charactersToSample ; i ++ ){
171
175
//Set up next input (single time step) by sampling from previous output
172
176
INDArray nextInput = Nd4j .zeros (numSamples ,iter .inputColumns ());
173
177
//Output is a probability distribution. Sample from this for each example we want to generate, and add it to the new input
174
178
for ( int s =0 ; s <numSamples ; s ++ ){
175
179
double [] outputProbDistribution = new double [iter .totalOutcomes ()];
176
- for ( int j = 0 ; j < outputProbDistribution .length ; j ++ ) outputProbDistribution [j ] = output .getDouble (s ,j );
180
+ for ( int j = 0 ; j < outputProbDistribution .length ; j ++) outputProbDistribution [j ] = output .getDouble (s ,j );
177
181
int sampledCharacterIdx = GenerateTxtModel .sampleFromDistribution (outputProbDistribution ,rng );
178
182
179
183
nextInput .putScalar (new int []{s ,sampledCharacterIdx }, 1.0f ); //Prepare next time step input
0 commit comments