@@ -179,7 +179,12 @@ Problem::FindSolutions(Handle& handle, const FindOptions& options, std::size_t m
179
179
auto ret = std::visit (
180
180
boost::hof::match (
181
181
[&](const ConvolutionDescriptor& op_desc) {
182
- return FindSolutionsImpl (handle, options, max_solutions, buffers, op_desc);
182
+ if (op_desc.mode == miopenTranspose)
183
+ return MakeTransposed ().FindSolutionsImpl (
184
+ handle, options, max_solutions, buffers, op_desc, *this );
185
+ else
186
+ return FindSolutionsImpl (
187
+ handle, options, max_solutions, buffers, op_desc, *this );
183
188
},
184
189
[&](const SoftmaxDescriptor& op_desc) {
185
190
return FindSolutionsImpl (handle, options, max_solutions, buffers, op_desc);
@@ -460,7 +465,8 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,
460
465
const FindOptions& options,
461
466
std::size_t max_solutions,
462
467
const Buffers& buffers,
463
- const ConvolutionDescriptor& conv_desc) const
468
+ const ConvolutionDescriptor& conv_desc,
469
+ const Problem& original) const
464
470
{
465
471
if (tensor_descriptors.size () != 3 )
466
472
{
@@ -477,21 +483,17 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,
477
483
const auto & w = buffers.at (miopenTensorConvolutionW);
478
484
auto y = buffers.at (miopenTensorConvolutionY);
479
485
480
- const auto conv_problem =
481
- conv_desc.mode == miopenTranspose ? MakeTransposed ().AsConvolution () : AsConvolution ();
482
-
483
- std::size_t workspace_size;
484
- Allocator::ManageDataPtr owned_workspace;
485
- Data_t workspace;
486
-
487
486
if (conv_desc.mode == miopenTranspose)
488
- {
489
487
std::swap (x, y);
490
- std::swap (x_desc, y_desc);
491
- }
488
+
489
+ const auto conv_problem = AsConvolution ();
492
490
493
491
ValidateGroupCount (x_desc, w_desc, conv_desc);
494
492
493
+ std::size_t workspace_size;
494
+ Allocator::ManageDataPtr owned_workspace;
495
+ Data_t workspace;
496
+
495
497
if (options.preallocated_workspace )
496
498
{
497
499
workspace = options.preallocated_workspace ->buffer ;
@@ -518,7 +520,7 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,
518
520
519
521
for (auto & result : results)
520
522
{
521
- result.SetProblem ({* this });
523
+ result.SetProblem ({original });
522
524
523
525
if (result.GetKernels ().empty ())
524
526
{
0 commit comments