|
242 | 242 | <div class="pytorch-left-menu-search">
|
243 | 243 |
|
244 | 244 | <div class="version">
|
245 |
| - <a href='https://pytorch.org/docs/versions.html'>main (2.3.0a0+gitf76e541 ) ▼</a> |
| 245 | + <a href='https://pytorch.org/docs/versions.html'>main (2.3.0a0+gitca96784 ) ▼</a> |
246 | 246 | </div>
|
247 | 247 |
|
248 | 248 |
|
@@ -785,7 +785,7 @@ <h1>Source code for torch.ao.nn.quantizable.modules.rnn</h1><div class="highligh
|
785 | 785 | <span class="bp">self</span><span class="o">.</span><span class="n">batch_first</span> <span class="o">=</span> <span class="n">batch_first</span>
|
786 | 786 | <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
|
787 | 787 | <span class="bp">self</span><span class="o">.</span><span class="n">bidirectional</span> <span class="o">=</span> <span class="n">bidirectional</span>
|
788 |
| - <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># We don't want to train using this module</span> |
| 788 | + <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># Default to eval mode. If we want to train, we will explicitly set to training.</span> |
789 | 789 | <span class="n">num_directions</span> <span class="o">=</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">bidirectional</span> <span class="k">else</span> <span class="mi">1</span>
|
790 | 790 |
|
791 | 791 | <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dropout</span><span class="p">,</span> <span class="n">numbers</span><span class="o">.</span><span class="n">Number</span><span class="p">)</span> <span class="ow">or</span> <span class="ow">not</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">dropout</span> <span class="o"><=</span> <span class="mi">1</span> <span class="ow">or</span> \
|
@@ -874,9 +874,14 @@ <h1>Source code for torch.ao.nn.quantizable.modules.rnn</h1><div class="highligh
|
874 | 874 | <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">other</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
|
875 | 875 | <span class="n">observed</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">_LSTMLayer</span><span class="o">.</span><span class="n">from_float</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">qconfig</span><span class="p">,</span>
|
876 | 876 | <span class="n">batch_first</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
877 |
| - <span class="c1"># TODO: Remove setting observed to eval to enable QAT.</span> |
878 |
| - <span class="n">observed</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> |
879 |
| - <span class="n">observed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ao</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">prepare</span><span class="p">(</span><span class="n">observed</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> |
| 877 | + |
| 878 | + <span class="c1"># Prepare the model</span> |
| 879 | + <span class="k">if</span> <span class="n">other</span><span class="o">.</span><span class="n">training</span><span class="p">:</span> |
| 880 | + <span class="n">observed</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> |
| 881 | + <span class="n">observed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ao</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">prepare_qat</span><span class="p">(</span><span class="n">observed</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> |
| 882 | + <span class="k">else</span><span class="p">:</span> |
| 883 | + <span class="n">observed</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> |
| 884 | + <span class="n">observed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ao</span><span class="o">.</span><span class="n">quantization</span><span class="o">.</span><span class="n">prepare</span><span class="p">(</span><span class="n">observed</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> |
880 | 885 | <span class="k">return</span> <span class="n">observed</span>
|
881 | 886 |
|
882 | 887 | <span class="nd">@classmethod</span>
|
|
0 commit comments