@@ -250,8 +250,8 @@ def to_fp8_col_major_t(
        block_col_offs[:, None] * output_stride_row
        + block_row_offs[None, :] * output_stride_col
    )
-    out_mask = (block_row_offs[:, None] < output_num_rows) & (
-        block_col_offs[None, :] < output_num_cols
+    out_mask = (block_col_offs[:, None] < output_num_rows) & (
+        block_row_offs[None, :] < output_num_cols
    )
    tl.store(out_ptr + out_offs, fp8_vals, mask=out_mask)

@@ -381,6 +381,77 @@ def _to_fp8_row_major_t_and_non_t(
    tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask)


+@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+@triton.jit
+def _to_fp8_col_major_t_and_non_t(
+    input_ptr,
+    col_major_out_ptr,
+    col_major_t_out_ptr,
+    scale_ptr,
+    num_elements: int,
+    fp8_dtype_min: float,
+    fp8_dtype_max: float,
+    input_num_rows: int,
+    input_num_cols: int,
+    input_stride_row: int,
+    input_stride_col: int,
+    col_major_out_stride_row: int,
+    col_major_out_stride_col: int,
+    col_major_t_out_stride_row: int,
+    col_major_t_out_stride_col: int,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE_ROWS: tl.constexpr,
+    BLOCK_SIZE_COLS: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    """
+    Reads a row-major, high-precision input tensor and writes 2 output tensors:
+    1) fp8 col major tensor
+    2) fp8 col major tensor (transposed)
+    """
+    # col major transposed
+    block_row_id = tl.program_id(axis=0)
+    block_col_id = tl.program_id(axis=1)
+
+    # load scaling factor
+    scale = tl.load(scale_ptr).to(tl.float32)
+
+    # load block of input tensor
+    block_row_start = block_row_id * BLOCK_SIZE_ROWS
+    block_col_start = block_col_id * BLOCK_SIZE_COLS
+    block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS)
+    block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS)
+    input_offs = (
+        block_row_offs[:, None] * input_stride_row
+        + block_col_offs[None, :] * input_stride_col
+    )
+    mask = (block_row_offs[:, None] < input_num_rows) & (
+        block_col_offs[None, :] < input_num_cols
+    )
+    vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype)
+
+    # perform conversion
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+
+    # 1. write col-major output
+    out_offs = block_row_offs[:, None] + block_col_offs[None, :] * input_num_rows
+    tl.store(col_major_out_ptr + out_offs, fp8_vals, mask=mask)
+
+    # 2. write transposed col-major output
+    col_major_t_num_rows = input_num_cols
+    col_major_t_num_cols = input_num_rows
+    out_offs = (
+        block_col_offs[:, None] * col_major_t_out_stride_row
+        + block_row_offs[None, :] * col_major_t_out_stride_col
+    )
+    out_mask = (block_col_offs[:, None] < col_major_t_num_rows) & (
+        block_row_offs[None, :] < col_major_t_num_cols
+    )
+    tl.store(col_major_t_out_ptr + out_offs, fp8_vals.trans(1, 0), mask=out_mask)
+
+
@triton.autotune(configs=kernel_configs_1D, key=["num_elements"])
@triton.jit
def _amax_atomic(
@@ -859,6 +930,93 @@ def hp_to_fp8_row_major_t_and_non_t(
    return fp8_tensor_row_major, fp8_tensor_row_major_t


+def hp_to_fp8_col_major_t_and_non_t(
+    hp_tensor: torch.Tensor,
+    fp8_dtype: torch.dtype,
+    linear_mm_config: LinearMMConfig,
+    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+) -> Float8Tensor:
+    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"
+
+    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
+    tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
+
+    fp8_dtype_min = torch.finfo(fp8_dtype).min
+    fp8_dtype_max = torch.finfo(fp8_dtype).max
+
+    # compute scaling factor for tensor
+    scale = _hp_tensor_to_scale(
+        hp_tensor,
+        tl_input_dtype,
+        fp8_dtype_max,
+        algo,
+    )
+
+    # perform fp8 conversion
+    input_num_rows, input_num_cols = hp_tensor.shape
+    num_elements = hp_tensor.numel()
+
+    # preallocate necessary output tensors
+    fp8_output_col_major = torch.empty(
+        (input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device
+    )
+    fp8_output_col_major_t = torch.empty_like(
+        hp_tensor.t(),
+        dtype=fp8_dtype,
+        device=hp_tensor.device,
+    )
+
+    # launch triton kernel to perform conversion
+    grid = lambda meta: (
+        triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]),
+        triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]),
+    )
+    _to_fp8_col_major_t_and_non_t[grid](
+        hp_tensor,
+        fp8_output_col_major,
+        fp8_output_col_major_t,
+        scale,
+        num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
+        input_num_rows,
+        input_num_cols,
+        hp_tensor.stride(0),
+        hp_tensor.stride(1),
+        fp8_output_col_major.stride(0),
+        fp8_output_col_major.stride(1),
+        fp8_output_col_major_t.stride(0),
+        fp8_output_col_major_t.stride(1),
+        input_dtype=tl_input_dtype,
+        output_dtype=tl_output_dtype,
+        EPS=EPS,
+    )
+
+    # for col major we need to update the strides to reflect the new memory layout
+    col_major_strides = (1, input_num_rows)
+    fp8_output_col_major = fp8_output_col_major.as_strided(
+        fp8_output_col_major.size(), col_major_strides
+    )
+
+    # wrap outputs in Float8Tensors
+    fp8_tensor_col_major = Float8Tensor(
+        fp8_output_col_major,
+        scale,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
+    fp8_tensor_col_major_t = Float8Tensor(
+        fp8_output_col_major_t,
+        scale,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
+    return fp8_tensor_col_major, fp8_tensor_col_major_t
+
+
def _hp_tensor_to_scale(
    hp_tensor: torch.Tensor,
    tl_input_dtype: tl.core.dtype,
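For reference, a minimal sketch of how the new hp_to_fp8_col_major_t_and_non_t entry point could be exercised. It assumes the kernels live in torchao's float8nocompile prototype module (the exact import path may differ) and that LinearMMConfig and GemmInputRole come from torchao.float8.float8_tensor; the shapes shown are illustrative only.

# Hypothetical usage sketch; import paths below are assumptions and may need adjusting.
import torch

from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (  # assumed path
    KernelAlgorithm,
    hp_to_fp8_col_major_t_and_non_t,
)

# high-precision, row-major (contiguous) input, as required by the assert in the wrapper
hp = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")

fp8_col_major, fp8_col_major_t = hp_to_fp8_col_major_t_and_non_t(
    hp,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    gemm_input_role=GemmInputRole.INPUT,
    algo=KernelAlgorithm.ATOMIC_MAX,
)

# Both Float8Tensors share one tensorwise scale. The non-transposed output keeps the
# input's (256, 512) logical shape (its raw fp8 data is re-strided to column-major
# (1, 256) by the wrapper); the transposed output has the swapped (512, 256) shape.
print(fp8_col_major.shape)    # torch.Size([256, 512])
print(fp8_col_major_t.shape)  # torch.Size([512, 256])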