@@ -443,57 +443,31 @@ kernel void upsample_bicubic2d_backward(
443
443
}
444
444
}
445
445
446
- #define INSTANTIATE_UPSAMPLE_BICUBIC (DTYPE ) \
447
- template [[host_name(" upsample_bicubic2d_" #DTYPE)]] kernel void \
448
- upsample_bicubic2d<DTYPE>( \
449
- constant DTYPE * inputData [[buffer(0 )]], \
450
- device DTYPE * outputData [[buffer(1 )]], \
451
- constant ulong4 & input_strides [[buffer(2 )]], \
452
- constant ulong4 & output_strides [[buffer(3 )]], \
453
- constant long4 & input_sizes [[buffer(4 )]], \
454
- constant long4 & output_sizes [[buffer(5 )]], \
455
- constant float2 & scales [[buffer(6 )]], \
456
- constant bool & align_corners [[buffer(7 )]], \
457
- uint thread_index [[thread_position_in_grid]])
458
-
459
- #define INSTANTIATE_UPSAMPLE_BILINEAR (DTYPE ) \
460
- template [[host_name(" upsample_bilinear2d_" #DTYPE)]] kernel void \
461
- upsample_bilinear2d<DTYPE>( \
462
- constant DTYPE * inputData [[buffer(0 )]], \
463
- device DTYPE * outputData [[buffer(1 )]], \
464
- constant ulong4 & input_strides [[buffer(2 )]], \
465
- constant ulong4 & output_strides [[buffer(3 )]], \
466
- constant long4 & input_sizes [[buffer(4 )]], \
467
- constant long4 & output_sizes [[buffer(5 )]], \
468
- constant float2 & scales [[buffer(6 )]], \
469
- constant bool & align_corners [[buffer(7 )]], \
470
- uint thread_index [[thread_position_in_grid]])
471
-
472
- #define INSTANTIATE_UPSAMPLE_BILINEAR_AA (DTYPE ) \
473
- template [[host_name(" upsample_bilinear2d_aa_" #DTYPE)]] kernel void \
474
- upsample_bilinear2d_aa<DTYPE>( \
475
- constant DTYPE * inputData [[buffer(0 )]], \
476
- device DTYPE * outputData [[buffer(1 )]], \
477
- constant ulong4 & input_strides [[buffer(2 )]], \
478
- constant ulong4 & output_strides [[buffer(3 )]], \
479
- constant long4 & input_sizes [[buffer(4 )]], \
480
- constant long4 & output_sizes [[buffer(5 )]], \
481
- constant float2 & scales [[buffer(6 )]], \
482
- constant bool & align_corners [[buffer(7 )]], \
483
- uint thread_index [[thread_position_in_grid]])
484
-
485
- #define INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD (DTYPE ) \
486
- template [[host_name(" upsample_bicubic2d_backward_" #DTYPE)]] kernel void \
487
- upsample_bicubic2d_backward<DTYPE>( \
488
- device AtomicType_t<DTYPE> * gradInputData [[buffer(0 )]], \
489
- constant DTYPE * gradOutputData [[buffer(1 )]], \
490
- constant ulong4 & input_strides [[buffer(2 )]], \
491
- constant ulong4 & output_strides [[buffer(3 )]], \
492
- constant long4 & input_sizes [[buffer(4 )]], \
493
- constant long4 & output_sizes [[buffer(5 )]], \
494
- constant float2 & scales [[buffer(6 )]], \
495
- constant bool & align_corners [[buffer(7 )]], \
496
- uint thread_index [[thread_position_in_grid]])
446
+ #define INSTANTIATE_UPSAMPLE_2D (NAME, DTYPE ) \
447
+ template [[host_name(" upsample_" #NAME " _" #DTYPE)]] kernel void \
448
+ upsample_##NAME<DTYPE>( \
449
+ constant DTYPE * inputData [[buffer(0 )]], \
450
+ device DTYPE * outputData [[buffer(1 )]], \
451
+ constant ulong4 & input_strides [[buffer(2 )]], \
452
+ constant ulong4 & output_strides [[buffer(3 )]], \
453
+ constant long4 & input_sizes [[buffer(4 )]], \
454
+ constant long4 & output_sizes [[buffer(5 )]], \
455
+ constant float2 & scales [[buffer(6 )]], \
456
+ constant bool & align_corners [[buffer(7 )]], \
457
+ uint thread_index [[thread_position_in_grid]])
458
+
459
+ #define INSTANTIATE_UPSAMPLE_2D_BACKWARD (NAME, DTYPE ) \
460
+ template [[host_name(" upsample_" #NAME " _backward_" #DTYPE)]] kernel void \
461
+ upsample_##NAME##_backward<DTYPE>( \
462
+ device AtomicType_t<DTYPE> * gradInputData [[buffer(0 )]], \
463
+ constant DTYPE * gradOutputData [[buffer(1 )]], \
464
+ constant ulong4 & input_strides [[buffer(2 )]], \
465
+ constant ulong4 & output_strides [[buffer(3 )]], \
466
+ constant long4 & input_sizes [[buffer(4 )]], \
467
+ constant long4 & output_sizes [[buffer(5 )]], \
468
+ constant float2 & scales [[buffer(6 )]], \
469
+ constant bool & align_corners [[buffer(7 )]], \
470
+ uint thread_index [[thread_position_in_grid]])
497
471
498
472
#define INSTANTIATE_UPSAMPLE_LINEAR (DTYPE ) \
499
473
template [[host_name(" upsample_linear1d_" #DTYPE)]] kernel void \
@@ -508,21 +482,16 @@ kernel void upsample_bicubic2d_backward(
508
482
constant bool & align_corners [[buffer(7 )]], \
509
483
uint thread_index [[thread_position_in_grid]])
510
484
511
- INSTANTIATE_UPSAMPLE_BILINEAR (uchar);
512
- INSTANTIATE_UPSAMPLE_BICUBIC (float );
513
- INSTANTIATE_UPSAMPLE_BILINEAR (float );
514
- INSTANTIATE_UPSAMPLE_BILINEAR_AA (float );
515
- INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD (float );
516
- INSTANTIATE_UPSAMPLE_BICUBIC (half);
517
- INSTANTIATE_UPSAMPLE_BILINEAR (half);
518
- INSTANTIATE_UPSAMPLE_BILINEAR_AA (half);
519
- INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD (half);
520
- INSTANTIATE_UPSAMPLE_LINEAR (float );
521
- INSTANTIATE_UPSAMPLE_LINEAR (half);
485
+ #define INSTANTIATE_UPSAMPLE_ALL (DTYPE ) \
486
+ INSTANTIATE_UPSAMPLE_2D (bicubic2d, DTYPE); \
487
+ INSTANTIATE_UPSAMPLE_2D_BACKWARD (bicubic2d, DTYPE); \
488
+ INSTANTIATE_UPSAMPLE_2D (bilinear2d, DTYPE); \
489
+ INSTANTIATE_UPSAMPLE_2D (bilinear2d_aa, DTYPE); \
490
+ INSTANTIATE_UPSAMPLE_LINEAR (DTYPE);
491
+
492
+ INSTANTIATE_UPSAMPLE_2D (bilinear2d, uchar);
493
+ INSTANTIATE_UPSAMPLE_ALL (float );
494
+ INSTANTIATE_UPSAMPLE_ALL (half);
522
495
#if __METAL_VERSION__ >= 310
523
- INSTANTIATE_UPSAMPLE_BICUBIC (bfloat);
524
- INSTANTIATE_UPSAMPLE_BILINEAR (bfloat);
525
- INSTANTIATE_UPSAMPLE_BILINEAR_AA (bfloat);
526
- INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD (bfloat);
527
- INSTANTIATE_UPSAMPLE_LINEAR (bfloat);
496
+ INSTANTIATE_UPSAMPLE_ALL (bfloat);
528
497
#endif
0 commit comments