@@ -55,7 +55,7 @@ public class GenerateTxtCharCompGraphModel {

    @SuppressWarnings("ConstantConditions")
    public static void main(String[] args ) throws Exception {
-        int lstmLayerSize = 200;                    //Number of units in each LSTM layer
+        int lstmLayerSize = 77;                     //Number of units in each LSTM layer
        int miniBatchSize = 32;                      //Size of mini batch to use when training
        int exampleLength = 1000;                    //Length of each training example sequence to use. This could certainly be increased
        int tbpttLength = 50;                        //Length for truncated backpropagation through time, i.e. do parameter updates every 50 characters
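
With exampleLength = 1000 and tbpttLength = 50, each example is trained as a series of short segments rather than one 1000-step backward pass; a quick check using only the constants above:

        // Truncated BPTT splits each example into exampleLength / tbpttLength
        // segments; gradients flow at most tbpttLength steps back in time.
        int segmentsPerExample = exampleLength / tbpttLength;   // 1000 / 50 = 20 updates per example
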
@@ -90,18 +90,20 @@ public static void main(String[] args ) throws Exception {
            //Output layer, name "outputLayer" with inputs from the two layers called "first" and "second"
            .addLayer("outputLayer", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX)
-                .nIn(2*lstmLayerSize).nOut(nOut).build(), "first","second")
+                .nIn(lstmLayerSize).nOut(lstmLayerSize).build(),"second")
            .setOutputs("outputLayer")  //List the output. For a ComputationGraph with multiple outputs, this also defines the input array orders
-            .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
+            .backpropType(BackpropType.TruncatedBPTT)
+                .tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
            .build();
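            // Shape bookkeeping for the .nIn() change above (illustrative): with inputs
            // "first","second" the output layer saw the two activations concatenated,
            // hence nIn = 2 * lstmLayerSize; fed from "second" alone it sees
            // nIn = lstmLayerSize. Note that .nOut(lstmLayerSize) only matches the
            // labels if the character-set size happens to equal lstmLayerSize (77).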

        ComputationGraph net = new ComputationGraph(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));
+        System.out.println(net.summary());

        //Print the number of parameters in the network (and for each layer)
        long totalNumParams = 0;
-        for( int i=0; i<net.getNumLayers(); i++ ){
+        for (int i = 0; i < net.getNumLayers(); i++) {
            long nParams = net.getLayer(i).numParams();
            System.out.println("Number of parameters in layer " + i + ": " + nParams);
            totalNumParams += nParams;
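
As a rough sanity check on the printed counts: a standard non-peephole LSTM layer has 4 * nOut * (nIn + nOut + 1) parameters (input weights, recurrent weights, and a bias for each of the four gates). Assuming the first layer sees a 77-character one-hot input:

        // params = 4 * nOut * (nIn + nOut + 1) for a plain (non-peephole) LSTM
        long firstLstmParams = 4L * 77 * (77 + 77 + 1);   // = 47,740 when nIn = nOut = 77
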
@@ -110,16 +112,18 @@ public static void main(String[] args ) throws Exception {

        //Do training, and then generate and print samples from network
        int miniBatchNumber = 0;
-        for( int i=0; i<numEpochs; i++ ){
+        for (int i = 0; i < numEpochs; i++) {
            while(iter.hasNext()){
                DataSet ds = iter.next();
+                System.out.println("Input shape " + ds.getFeatures().shapeInfoToString());
+                System.out.println("Labels " + ds.getLabels().shapeInfoToString());
                net.fit(ds);
                if(++miniBatchNumber % generateSamplesEveryNMinibatches == 0){
                    System.out.println("--------------------");
                    System.out.println("Completed " + miniBatchNumber + " minibatches of size " + miniBatchSize + "x" + exampleLength + " characters");
                    System.out.println("Sampling characters from network given initialization \"" + (generationInitialization == null ? "" : generationInitialization) + "\"");
                    String[] samples = sampleCharactersFromNetwork(generationInitialization,net,iter,rng,nCharactersToSample,nSamplesToGenerate);
-                    for( int j=0; j<samples.length; j++ ){
+                    for (int j = 0; j < samples.length; j++) {
                        System.out.println("----- Sample " + j + " -----");
                        System.out.println(samples[j]);
                        System.out.println();
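
The two shape prints added in this hunk should show rank-3 one-hot tensors. Assuming the iterator encodes a 77-character vocabulary, the expected report is:

                // features: [miniBatchSize, numCharacters, exampleLength]  e.g. [32, 77, 1000]
                // labels:   [miniBatchSize, numCharacters, exampleLength]  e.g. [32, 77, 1000]
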
@@ -135,7 +139,7 @@ public static void main(String[] args ) throws Exception {

    /** Generate a sample from the network, given an (optional, possibly null) initialization. Initialization
     * can be used to 'prime' the RNN with a sequence you want to extend/continue.<br>
-     * Note that the initalization is used for all samples
+     * Note that the initialization is used for all samples
     * @param initialization String, may be null. If null, select a random character as initialization for all samples
     * @param charactersToSample Number of characters to sample from network (excluding initialization)
     * @param net ComputationGraph with one or more LSTM/RNN layers and a softmax output layer
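
A hypothetical call site, reusing the variables defined in main (the prime string and counts here are purely illustrative):

    // Prime the RNN with "The " and draw 4 independent 300-character continuations
    String[] samples = sampleCharactersFromNetwork("The ", net, iter, rng, 300, 4);
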
@@ -151,9 +155,9 @@ private static String[] sampleCharactersFromNetwork( String initialization, Comp
        //Create input for initialization
        INDArray initializationInput = Nd4j.zeros(numSamples, iter.inputColumns(), initialization.length());
        char[] init = initialization.toCharArray();
-        for( int i=0; i<init.length; i++ ){
+        for( int i=0; i<init.length; i++) {
            int idx = iter.convertCharacterToIndex(init[i]);
-            for( int j=0; j<numSamples; j++ ){
+            for( int j = 0; j<numSamples; j++ ){
                initializationInput.putScalar(new int[]{j,idx,i}, 1.0f);
            }
        }
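
A worked example of what this loop builds: priming with "ab" for numSamples = 2 yields a [2, numCharacters, 2] tensor whose only non-zero entries are:

        // initializationInput[s, convertCharacterToIndex('a'), 0] = 1.0   for s = 0, 1
        // initializationInput[s, convertCharacterToIndex('b'), 1] = 1.0   for s = 0, 1
        // i.e. one-hot on the character axis, one time step per primed character
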
@@ -167,13 +171,13 @@ private static String[] sampleCharactersFromNetwork( String initialization, Comp
        INDArray output = net.rnnTimeStep(initializationInput)[0];
        output = output.tensorAlongDimension((int)output.size(2)-1,1,0);    //Gets the last time step output

-        for( int i=0; i<charactersToSample; i++ ){
+        for (int i = 0; i < charactersToSample; i++ ){
            //Set up next input (single time step) by sampling from previous output
            INDArray nextInput = Nd4j.zeros(numSamples,iter.inputColumns());
            //Output is a probability distribution. Sample from this for each example we want to generate, and add it to the new input
            for( int s=0; s<numSamples; s++ ){
                double[] outputProbDistribution = new double[iter.totalOutcomes()];
-                for( int j=0; j<outputProbDistribution.length; j++ ) outputProbDistribution[j] = output.getDouble(s,j);
+                for (int j = 0; j < outputProbDistribution.length; j++) outputProbDistribution[j] = output.getDouble(s,j);
                int sampledCharacterIdx = GenerateTxtModel.sampleFromDistribution(outputProbDistribution,rng);

                nextInput.putScalar(new int[]{s,sampledCharacterIdx}, 1.0f);        //Prepare next time step input
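
GenerateTxtModel.sampleFromDistribution is defined outside this diff; categorical sampling of this kind is usually an inverse-CDF walk, sketched below as an assumption about its behaviour rather than its actual source:

    // Draw one index from a categorical distribution by walking the CDF
    static int sampleFromDistributionSketch(double[] distribution, java.util.Random rng) {
        double d = rng.nextDouble();          // uniform draw in [0, 1)
        double cumulative = 0.0;
        for (int i = 0; i < distribution.length; i++) {
            cumulative += distribution[i];
            if (d < cumulative) return i;     // first index whose CDF passes the draw
        }
        return distribution.length - 1;       // guard against floating-point round-off
    }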