Skip to content

Commit d174562

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS][BE][EZ] Aggregate macros (pytorch#148187)
Refactor `INSTANTIATE_UPSAMPLE_BILINEAR2D(DTYPE)`, `INSTANTIATE_UPSAMPLE_BICUBIC2D(DTYPE)` and `INSTANTIATE_UPSAMPLE_BILINEAR2DAA(DTYPE)` use common `INSTANTIATE_UPSAMPLE2D` Then combine multiple invocations into `INSTANTIATE_UPSAMPLE_ALL` I.e. functionally it's a no-op, but achieves the same with fewer lines of code Pull Request resolved: pytorch#148187 Approved by: https://github.com/Skylion007 ghstack dependencies: pytorch#148154
1 parent 4995e05 commit d174562

File tree

1 file changed

+36
-67
lines changed

1 file changed

+36
-67
lines changed

aten/src/ATen/native/mps/kernels/UpSample.metal

Lines changed: 36 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -443,57 +443,31 @@ kernel void upsample_bicubic2d_backward(
443443
}
444444
}
445445

446-
#define INSTANTIATE_UPSAMPLE_BICUBIC(DTYPE) \
447-
template [[host_name("upsample_bicubic2d_" #DTYPE)]] kernel void \
448-
upsample_bicubic2d<DTYPE>( \
449-
constant DTYPE * inputData [[buffer(0)]], \
450-
device DTYPE * outputData [[buffer(1)]], \
451-
constant ulong4 & input_strides [[buffer(2)]], \
452-
constant ulong4 & output_strides [[buffer(3)]], \
453-
constant long4 & input_sizes [[buffer(4)]], \
454-
constant long4 & output_sizes [[buffer(5)]], \
455-
constant float2 & scales [[buffer(6)]], \
456-
constant bool& align_corners [[buffer(7)]], \
457-
uint thread_index [[thread_position_in_grid]])
458-
459-
#define INSTANTIATE_UPSAMPLE_BILINEAR(DTYPE) \
460-
template [[host_name("upsample_bilinear2d_" #DTYPE)]] kernel void \
461-
upsample_bilinear2d<DTYPE>( \
462-
constant DTYPE * inputData [[buffer(0)]], \
463-
device DTYPE * outputData [[buffer(1)]], \
464-
constant ulong4 & input_strides [[buffer(2)]], \
465-
constant ulong4 & output_strides [[buffer(3)]], \
466-
constant long4 & input_sizes [[buffer(4)]], \
467-
constant long4 & output_sizes [[buffer(5)]], \
468-
constant float2 & scales [[buffer(6)]], \
469-
constant bool& align_corners [[buffer(7)]], \
470-
uint thread_index [[thread_position_in_grid]])
471-
472-
#define INSTANTIATE_UPSAMPLE_BILINEAR_AA(DTYPE) \
473-
template [[host_name("upsample_bilinear2d_aa_" #DTYPE)]] kernel void \
474-
upsample_bilinear2d_aa<DTYPE>( \
475-
constant DTYPE * inputData [[buffer(0)]], \
476-
device DTYPE * outputData [[buffer(1)]], \
477-
constant ulong4 & input_strides [[buffer(2)]], \
478-
constant ulong4 & output_strides [[buffer(3)]], \
479-
constant long4 & input_sizes [[buffer(4)]], \
480-
constant long4 & output_sizes [[buffer(5)]], \
481-
constant float2 & scales [[buffer(6)]], \
482-
constant bool& align_corners [[buffer(7)]], \
483-
uint thread_index [[thread_position_in_grid]])
484-
485-
#define INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(DTYPE) \
486-
template [[host_name("upsample_bicubic2d_backward_" #DTYPE)]] kernel void \
487-
upsample_bicubic2d_backward<DTYPE>( \
488-
device AtomicType_t<DTYPE> * gradInputData [[buffer(0)]], \
489-
constant DTYPE * gradOutputData [[buffer(1)]], \
490-
constant ulong4 & input_strides [[buffer(2)]], \
491-
constant ulong4 & output_strides [[buffer(3)]], \
492-
constant long4 & input_sizes [[buffer(4)]], \
493-
constant long4 & output_sizes [[buffer(5)]], \
494-
constant float2 & scales [[buffer(6)]], \
495-
constant bool& align_corners [[buffer(7)]], \
496-
uint thread_index [[thread_position_in_grid]])
446+
#define INSTANTIATE_UPSAMPLE_2D(NAME, DTYPE) \
447+
template [[host_name("upsample_" #NAME "_" #DTYPE)]] kernel void \
448+
upsample_##NAME<DTYPE>( \
449+
constant DTYPE * inputData [[buffer(0)]], \
450+
device DTYPE * outputData [[buffer(1)]], \
451+
constant ulong4 & input_strides [[buffer(2)]], \
452+
constant ulong4 & output_strides [[buffer(3)]], \
453+
constant long4 & input_sizes [[buffer(4)]], \
454+
constant long4 & output_sizes [[buffer(5)]], \
455+
constant float2 & scales [[buffer(6)]], \
456+
constant bool& align_corners [[buffer(7)]], \
457+
uint thread_index [[thread_position_in_grid]])
458+
459+
#define INSTANTIATE_UPSAMPLE_2D_BACKWARD(NAME, DTYPE) \
460+
template [[host_name("upsample_" #NAME "_backward_" #DTYPE)]] kernel void \
461+
upsample_##NAME##_backward<DTYPE>( \
462+
device AtomicType_t<DTYPE> * gradInputData [[buffer(0)]], \
463+
constant DTYPE * gradOutputData [[buffer(1)]], \
464+
constant ulong4 & input_strides [[buffer(2)]], \
465+
constant ulong4 & output_strides [[buffer(3)]], \
466+
constant long4 & input_sizes [[buffer(4)]], \
467+
constant long4 & output_sizes [[buffer(5)]], \
468+
constant float2 & scales [[buffer(6)]], \
469+
constant bool& align_corners [[buffer(7)]], \
470+
uint thread_index [[thread_position_in_grid]])
497471

498472
#define INSTANTIATE_UPSAMPLE_LINEAR(DTYPE) \
499473
template [[host_name("upsample_linear1d_" #DTYPE)]] kernel void \
@@ -508,21 +482,16 @@ kernel void upsample_bicubic2d_backward(
508482
constant bool& align_corners [[buffer(7)]], \
509483
uint thread_index [[thread_position_in_grid]])
510484

511-
INSTANTIATE_UPSAMPLE_BILINEAR(uchar);
512-
INSTANTIATE_UPSAMPLE_BICUBIC(float);
513-
INSTANTIATE_UPSAMPLE_BILINEAR(float);
514-
INSTANTIATE_UPSAMPLE_BILINEAR_AA(float);
515-
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(float);
516-
INSTANTIATE_UPSAMPLE_BICUBIC(half);
517-
INSTANTIATE_UPSAMPLE_BILINEAR(half);
518-
INSTANTIATE_UPSAMPLE_BILINEAR_AA(half);
519-
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(half);
520-
INSTANTIATE_UPSAMPLE_LINEAR(float);
521-
INSTANTIATE_UPSAMPLE_LINEAR(half);
485+
#define INSTANTIATE_UPSAMPLE_ALL(DTYPE) \
486+
INSTANTIATE_UPSAMPLE_2D(bicubic2d, DTYPE); \
487+
INSTANTIATE_UPSAMPLE_2D_BACKWARD(bicubic2d, DTYPE); \
488+
INSTANTIATE_UPSAMPLE_2D(bilinear2d, DTYPE); \
489+
INSTANTIATE_UPSAMPLE_2D(bilinear2d_aa, DTYPE); \
490+
INSTANTIATE_UPSAMPLE_LINEAR(DTYPE);
491+
492+
INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar);
493+
INSTANTIATE_UPSAMPLE_ALL(float);
494+
INSTANTIATE_UPSAMPLE_ALL(half);
522495
#if __METAL_VERSION__ >= 310
523-
INSTANTIATE_UPSAMPLE_BICUBIC(bfloat);
524-
INSTANTIATE_UPSAMPLE_BILINEAR(bfloat);
525-
INSTANTIATE_UPSAMPLE_BILINEAR_AA(bfloat);
526-
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(bfloat);
527-
INSTANTIATE_UPSAMPLE_LINEAR(bfloat);
496+
INSTANTIATE_UPSAMPLE_ALL(bfloat);
528497
#endif

0 commit comments

Comments
 (0)