@@ -305,6 +305,82 @@ def _to_fp8_row_and_col_major(
     tl.store(col_major_out_ptr + col_major_offs, fp8_vals, mask=mask)


+@triton.autotune(
+    configs=kernel_configs_2D,
+    key=["num_elements"],
+)
+@triton.jit
+def _to_fp8_row_major_t_and_non_t(
+    input_ptr,
+    row_major_out_ptr,
+    row_major_t_out_ptr,
+    scale_ptr,
+    num_elements: int,
+    fp8_dtype_min: float,
+    fp8_dtype_max: float,
+    input_num_rows: int,
+    input_num_cols: int,
+    input_stride_row: int,
+    input_stride_col: int,
+    row_major_out_stride_row: int,
+    row_major_out_stride_col: int,
+    row_major_t_out_stride_row: int,
+    row_major_t_out_stride_col: int,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE_ROWS: tl.constexpr,
+    BLOCK_SIZE_COLS: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    """
+    Reads a row-major, high-precision input tensor and writes 2 output tensors:
+    1) fp8 row-major tensor
+    2) transpose of the fp8 tensor, also stored in row-major memory layout
+    """
+    block_row_id = tl.program_id(axis=0)
+    block_col_id = tl.program_id(axis=1)
+
+    # load scaling factor
+    scale = tl.load(scale_ptr).to(tl.float32)
+
+    # load block of input tensor
+    block_row_start = block_row_id * BLOCK_SIZE_ROWS
+    block_col_start = block_col_id * BLOCK_SIZE_COLS
+    block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS)
+    block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS)
+    input_offs = (
+        block_row_offs[:, None] * input_stride_row
+        + block_col_offs[None, :] * input_stride_col
+    )
+    mask = (block_row_offs[:, None] < input_num_rows) & (
+        block_col_offs[None, :] < input_num_cols
+    )
+    vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype)
+
+    # perform conversion
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+
+    # write row-major output
+    row_major_offs = (
+        block_row_offs[:, None] * row_major_out_stride_row
+        + block_col_offs[None, :] * row_major_out_stride_col
+    )
+    tl.store(row_major_out_ptr + row_major_offs, fp8_vals, mask=mask)
+
+    # write transposed row-major output
+    row_major_t_num_rows = input_num_cols
+    row_major_t_num_cols = input_num_rows
+    row_major_t_offs = (
+        block_col_offs[:, None] * row_major_t_out_stride_row
+        + block_row_offs[None, :] * row_major_t_out_stride_col
+    )
+    # the transposed tile indexes rows with block_col_offs and columns with
+    # block_row_offs, so the bounds check must follow the same orientation
+    mask = (block_col_offs[:, None] < row_major_t_num_rows) & (
+        block_row_offs[None, :] < row_major_t_num_cols
+    )
+    tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask)
+
+
 @triton.autotune(configs=kernel_configs_1D, key=["num_elements"])
 @triton.jit
 def _amax_atomic(
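Note on the kernel above: each program reads one tile of the high-precision input once and writes it twice, at its original coordinates and at the transposed coordinates, so the fp8 tensor and its row-major transpose come out of a single pass over the input. Below is a minimal plain-PyTorch sketch of the expected result, useful as an eager reference check. The amax-based scale computation in the sketch is an illustrative assumption; the kernel itself receives a precomputed scale via `scale_ptr`.

```python
import torch


def fp8_row_major_t_and_non_t_reference(hp: torch.Tensor, fp8_dtype: torch.dtype):
    """Eager reference for the fused kernel's two outputs (sketch, not the kernel)."""
    finfo = torch.finfo(fp8_dtype)
    # tensor-wise scale: map the max magnitude onto the fp8 max representable value
    amax = hp.abs().max().to(torch.float64)
    scale = (finfo.max / torch.clamp(amax, min=1e-12)).to(torch.float32)
    scaled = hp.to(torch.float32) * scale
    fp8_row_major = scaled.clamp(finfo.min, finfo.max).to(fp8_dtype)
    # the second output is the same data, transposed and materialized contiguously
    fp8_row_major_t = fp8_row_major.t().contiguous()
    return fp8_row_major, fp8_row_major_t, scale


x = torch.randn(256, 512, dtype=torch.bfloat16)
row_major, row_major_t, scale = fp8_row_major_t_and_non_t_reference(x, torch.float8_e4m3fn)
assert row_major.shape == (256, 512) and row_major_t.shape == (512, 256)
```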
@@ -701,6 +777,88 @@ def hp_to_fp8_row_and_col_major(
     return fp8_tensor_row_major, fp8_tensor_col_major


+def hp_to_fp8_row_major_t_and_non_t(
+    hp_tensor: torch.Tensor,
+    fp8_dtype: torch.dtype,
+    linear_mm_config: LinearMMConfig,
+    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+) -> Float8Tensor:
+    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"
+
+    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
+    tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]
+
+    fp8_dtype_min = torch.finfo(fp8_dtype).min
+    fp8_dtype_max = torch.finfo(fp8_dtype).max
+
+    # compute scaling factor for tensor
+    scale = _hp_tensor_to_scale(
+        hp_tensor,
+        tl_input_dtype,
+        fp8_dtype_max,
+        algo,
+    )
+
+    # perform fp8 conversion
+    input_num_rows, input_num_cols = hp_tensor.shape
+    transposed_num_rows, transposed_num_cols = input_num_cols, input_num_rows
+    num_elements = hp_tensor.numel()
+
+    # preallocate necessary output tensors
+    fp8_output_row_major = torch.empty(
+        (input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device
+    )
+    fp8_output_row_major_t = torch.empty(
+        (transposed_num_rows, transposed_num_cols),
+        dtype=fp8_dtype,
+        device=hp_tensor.device,
+    )
+
+    # launch triton kernel to perform conversion
+    grid = lambda meta: (
+        triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]),
+        triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]),
+    )
+    _to_fp8_row_major_t_and_non_t[grid](
+        hp_tensor,
+        fp8_output_row_major,
+        fp8_output_row_major_t,
+        scale,
+        num_elements,
+        fp8_dtype_min,
+        fp8_dtype_max,
+        input_num_rows,
+        input_num_cols,
+        hp_tensor.stride(0),
+        hp_tensor.stride(1),
+        fp8_output_row_major.stride(0),
+        fp8_output_row_major.stride(1),
+        fp8_output_row_major_t.stride(0),
+        fp8_output_row_major_t.stride(1),
+        input_dtype=tl_input_dtype,
+        output_dtype=tl_output_dtype,
+        EPS=EPS,
+    )
+
+    # wrap outputs in Float8Tensors
+    fp8_tensor_row_major = Float8Tensor(
+        fp8_output_row_major,
+        scale,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
+    fp8_tensor_row_major_t = Float8Tensor(
+        fp8_output_row_major_t,
+        scale,
+        orig_dtype=hp_tensor.dtype,
+        linear_mm_config=linear_mm_config,
+        gemm_input_role=gemm_input_role,
+    )
+    return fp8_tensor_row_major, fp8_tensor_row_major_t
+
+
 def _hp_tensor_to_scale(
     hp_tensor: torch.Tensor,
     tl_input_dtype: tl.core.dtype,
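For context, a hedged usage sketch of the new `hp_to_fp8_row_major_t_and_non_t` wrapper. The default `LinearMMConfig()` construction, the choice of `GemmInputRole.WEIGHT`, and the CUDA device are assumptions for illustration; real callers pass the config and gemm input role from their float8 linear setup.

```python
import torch

# assumes the names from this module (hp_to_fp8_row_major_t_and_non_t,
# LinearMMConfig, GemmInputRole, KernelAlgorithm) are in scope, and a CUDA
# device is available for the Triton kernel launch
hp_weight = torch.randn(4096, 1024, dtype=torch.bfloat16, device="cuda")

fp8_w, fp8_w_t = hp_to_fp8_row_major_t_and_non_t(
    hp_weight,
    torch.float8_e4m3fn,
    LinearMMConfig(),                      # default config (assumption for this sketch)
    gemm_input_role=GemmInputRole.WEIGHT,
    algo=KernelAlgorithm.ATOMIC_MAX,
)

# both Float8Tensors share one tensor-wise scale; the second holds the transpose
# materialized in row-major memory, avoiding a separate transpose kernel later
assert fp8_w.shape == (4096, 1024)
assert fp8_w_t.shape == (1024, 4096)
```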