@@ -177,9 +177,9 @@ NNModule THSNN_ReLU_ctor(bool inplace,
177
177
// );
178
178
}
179
179
180
- Tensor THSNN_ReLU_forward (const NNAnyModule module, const Tensor tensor)
180
+ Tensor THSNN_ReLU_forward (const NNModule module, const Tensor tensor)
181
181
{
182
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::ReLU>()->forward (*tensor));
182
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::ReLU>()->forward (*tensor));
183
183
}
184
184
185
185
NNModule THSNN_Dropout_ctor (double probability,
@@ -197,9 +197,9 @@ NNModule THSNN_Dropout_ctor(double probability,
197
197
);
198
198
}
199
199
200
- Tensor THSNN_Dropout_forward (const NNAnyModule module, const Tensor tensor)
200
+ Tensor THSNN_Dropout_forward (const NNModule module, const Tensor tensor)
201
201
{
202
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::Dropout>()->forward (*tensor));
202
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::Dropout>()->forward (*tensor));
203
203
}
204
204
205
205
NNModule THSNN_FeatureAlphaDropout_ctor (double probability,
@@ -216,9 +216,9 @@ NNModule THSNN_FeatureAlphaDropout_ctor(double probability,
216
216
);
217
217
}
218
218
219
- Tensor THSNN_FeatureAlphaDropout_forward (const NNAnyModule module, const Tensor tensor)
219
+ Tensor THSNN_FeatureAlphaDropout_forward (const NNModule module, const Tensor tensor)
220
220
{
221
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::FeatureAlphaDropout>()->forward (*tensor));
221
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::FeatureAlphaDropout>()->forward (*tensor));
222
222
}
223
223
224
224
NNModule THSNN_LogSoftMax_ctor (int64_t dim,
@@ -239,9 +239,9 @@ NNModule THSNN_LogSoftMax_ctor(int64_t dim,
239
239
);
240
240
}
241
241
242
- Tensor THSNN_LogSoftMax_forward (const NNAnyModule module, const Tensor tensor)
242
+ Tensor THSNN_LogSoftMax_forward (const NNModule module, const Tensor tensor)
243
243
{
244
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::LogSoftmax>()->forward (*tensor));
244
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::LogSoftmax>()->forward (*tensor));
245
245
}
246
246
247
247
NNModule THSNN_AvgPool2d_ctor (const int64_t * kernelSize, const int kernelSizeLength, const int64_t * stride, const int strideLength,
@@ -264,9 +264,9 @@ NNModule THSNN_AvgPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLen
264
264
);
265
265
}
266
266
267
- Tensor THSNN_AvgPool2d_forward (const NNAnyModule module, const Tensor tensor)
267
+ Tensor THSNN_AvgPool2d_forward (const NNModule module, const Tensor tensor)
268
268
{
269
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::AvgPool2d>()->forward (*tensor));
269
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::AvgPool2d>()->forward (*tensor));
270
270
}
271
271
272
272
NNModule THSNN_AdaptiveAvgPool2d_ctor (const int64_t * kernelSize, const int kernelSizeLength,
@@ -281,9 +281,9 @@ NNModule THSNN_AdaptiveAvgPool2d_ctor(const int64_t* kernelSize, const int kerne
281
281
);
282
282
}
283
283
284
- Tensor THSNN_AdaptiveAvgPool2d_forward (const NNAnyModule module, const Tensor tensor)
284
+ Tensor THSNN_AdaptiveAvgPool2d_forward (const NNModule module, const Tensor tensor)
285
285
{
286
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::AdaptiveAvgPool2d>()->forward (*tensor));
286
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::AdaptiveAvgPool2d>()->forward (*tensor));
287
287
}
288
288
289
289
NNModule THSNN_MaxPool2d_ctor (const int64_t * kernelSize, const int kernelSizeLength, const int64_t * stride, const int strideLength,
@@ -304,9 +304,9 @@ NNModule THSNN_MaxPool2d_ctor(const int64_t* kernelSize, const int kernelSizeLen
304
304
)
305
305
}
306
306
307
- Tensor THSNN_MaxPool2d_forward (const NNAnyModule module, const Tensor tensor)
307
+ Tensor THSNN_MaxPool2d_forward (const NNModule module, const Tensor tensor)
308
308
{
309
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::MaxPool2d>()->forward (*tensor));
309
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::MaxPool2d>()->forward (*tensor));
310
310
}
311
311
312
312
NNModule THSNN_Linear_ctor (const int64_t input_size, const int64_t output_size, const bool bias,
@@ -327,32 +327,32 @@ NNModule THSNN_Linear_ctor(const int64_t input_size, const int64_t output_size,
327
327
);
328
328
}
329
329
330
- Tensor THSNN_Linear_forward (const NNAnyModule module, const Tensor tensor)
330
+ Tensor THSNN_Linear_forward (const NNModule module, const Tensor tensor)
331
331
{
332
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::Linear>()->forward (*tensor));
332
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::Linear>()->forward (*tensor));
333
333
}
334
334
335
- Tensor THSNN_Linear_bias (const NNAnyModule module)
335
+ Tensor THSNN_Linear_bias (const NNModule module)
336
336
{
337
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::Linear>()->bias );
337
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::Linear>()->bias );
338
338
}
339
339
340
- void THSNN_Linear_set_bias (const NNAnyModule module, const Tensor bias)
340
+ void THSNN_Linear_set_bias (const NNModule module, const Tensor bias)
341
341
{
342
342
CATCH (
343
- (*module)->get <torch::nn::Linear>()->bias = *bias;
343
+ (*module)->as <torch::nn::Linear>()->bias = *bias;
344
344
)
345
345
}
346
346
347
- Tensor THSNN_Linear_weight (const NNAnyModule module)
347
+ Tensor THSNN_Linear_weight (const NNModule module)
348
348
{
349
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::Linear>()->weight );
349
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::Linear>()->weight );
350
350
}
351
351
352
- void THSNN_Linear_set_weight (const NNAnyModule module, const Tensor weight)
352
+ void THSNN_Linear_set_weight (const NNModule module, const Tensor weight)
353
353
{
354
354
CATCH (
355
- (*module)->get <torch::nn::Linear>()->weight = *weight;
355
+ (*module)->as <torch::nn::Linear>()->weight = *weight;
356
356
)
357
357
}
358
358
@@ -376,9 +376,9 @@ NNModule THSNN_Conv2d_ctor(const int64_t inputChannel, const int64_t outputChann
376
376
);
377
377
}
378
378
379
- Tensor THSNN_Conv2d_forward (const NNAnyModule module, const Tensor tensor)
379
+ Tensor THSNN_Conv2d_forward (const NNModule module, const Tensor tensor)
380
380
{
381
- CATCH_RETURN_TENSOR ((*module)->get <torch::nn::Conv2d>()->forward (*tensor));
381
+ CATCH_RETURN_TENSOR ((*module)->as <torch::nn::Conv2d>()->forward (*tensor));
382
382
}
383
383
384
384
NNSequential THSNN_Sequential_ctor ()
0 commit comments