@@ -77,7 +77,7 @@ The `addi` operation takes two operands and returns one result, each of
77
77
these is required to be the same type. This type may be an integer scalar type,
78
78
a vector whose element type is integer, or a tensor of integers.
79
79
80
- This op supports `nuw`/`nsw` overflow flags which stands stand for
80
+ This op supports `nuw`/`nsw` overflow flags which stands for
81
81
\" No Unsigned Wrap\" and \" No Signed Wrap\" , respectively. If the `nuw` and/or
82
82
`nsw` flags are present, and an unsigned/signed overflow occurs
83
83
(respectively), the result is poison.
@@ -1193,7 +1193,7 @@ The `muli` operation takes two operands and returns one result, each of
1193
1193
these is required to be the same type. This type may be an integer scalar type,
1194
1194
a vector whose element type is integer, or a tensor of integers.
1195
1195
1196
- This op supports `nuw`/`nsw` overflow flags which stands stand for
1196
+ This op supports `nuw`/`nsw` overflow flags which stands for
1197
1197
\" No Unsigned Wrap\" and \" No Signed Wrap\" , respectively. If the `nuw` and/or
1198
1198
`nsw` flags are present, and an unsigned/signed overflow occurs
1199
1199
(respectively), the result is poison.
@@ -1578,6 +1578,129 @@ function sitofp(in::Value; out::IR.Type, location=Location())
1578
1578
)
1579
1579
end
1580
1580
1581
+ """
1582
+ `scaling_extf`
1583
+
1584
+ This operation upcasts input floating-point values using provided scale
1585
+ values. It expects both scales and the input operand to be of the same shape,
1586
+ making the operation elementwise. Scales are usually calculated per block
1587
+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
1588
+
1589
+ If scales are calculated per block where blockSize != 1, then scales may
1590
+ require broadcasting to make this operation elementwise. For example, let\' s
1591
+ say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1592
+ assuming quantization happens on the last axis, the input can be reshaped to
1593
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1594
+ per block on the last axis. Therefore, scales will be of shape
1595
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1596
+ shape as long as it is broadcast compatible with the input, e.g.,
1597
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
1598
+
1599
+ In this example, before calling into `arith.scaling_extf`, scales must be
1600
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1601
+ that there could be multiple quantization axes. Internally,
1602
+ `arith.scaling_extf` would perform the following:
1603
+
1604
+ ```
1605
+ resultTy = get_type(result)
1606
+ scaleTy = get_type(scale)
1607
+ inputTy = get_type(input)
1608
+ scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
1609
+ scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
1610
+ input.extf = arith.extf(input) : inputTy to resultTy
1611
+ result = arith.mulf(scale.extf, input.extf)
1612
+ ```
1613
+ It propagates NaN values. Therefore, if either scale or the input element
1614
+ contains NaN, then the output element value will also be a NaN.
1615
+ """
1616
+ function scaling_extf (
1617
+ in:: Value , scale:: Value ; out:: IR.Type , fastmath= nothing , location= Location ()
1618
+ )
1619
+ op_ty_results = IR. Type[out,]
1620
+ operands = Value[in, scale]
1621
+ owned_regions = Region[]
1622
+ successors = Block[]
1623
+ attributes = NamedAttribute[]
1624
+ ! isnothing (fastmath) && push! (attributes, namedattribute (" fastmath" , fastmath))
1625
+
1626
+ return create_operation (
1627
+ " arith.scaling_extf" ,
1628
+ location;
1629
+ operands,
1630
+ owned_regions,
1631
+ successors,
1632
+ attributes,
1633
+ results= op_ty_results,
1634
+ result_inference= false ,
1635
+ )
1636
+ end
1637
+
1638
+ """
1639
+ `scaling_truncf`
1640
+
1641
+ This operation downcasts input using the provided scale values. It expects
1642
+ both scales and the input operand to be of the same shape and, therefore,
1643
+ makes the operation elementwise. Scales are usually calculated per block
1644
+ following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
1645
+ Users are required to normalize and clamp the scales as necessary before calling
1646
+ passing them to this operation. OCP MXFP spec also does the flushing of denorms
1647
+ on the input operand, which should be handled during lowering by passing appropriate
1648
+ fastMath flag to this operation.
1649
+
1650
+ If scales are calculated per block where blockSize != 1, scales may require
1651
+ broadcasting to make this operation elementwise. For example, let\' s say the
1652
+ input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1653
+ assuming quantization happens on the last axis, the input can be reshaped to
1654
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1655
+ per block on the last axis. Therefore, scales will be of shape
1656
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1657
+ shape as long as it is broadcast compatible with the input, e.g.,
1658
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
1659
+
1660
+ In this example, before calling into `arith.scaling_truncf`, scales must be
1661
+ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1662
+ that there could be multiple quantization axes. Internally,
1663
+ `arith.scaling_truncf` would perform the following:
1664
+
1665
+ ```
1666
+ scaleTy = get_type(scale)
1667
+ inputTy = get_type(input)
1668
+ resultTy = get_type(result)
1669
+ scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
1670
+ scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
1671
+ result = arith.divf(input, scale.extf)
1672
+ result.cast = arith.truncf(result, resultTy)
1673
+ ```
1674
+ """
1675
+ function scaling_truncf (
1676
+ in:: Value ,
1677
+ scale:: Value ;
1678
+ out:: IR.Type ,
1679
+ roundingmode= nothing ,
1680
+ fastmath= nothing ,
1681
+ location= Location (),
1682
+ )
1683
+ op_ty_results = IR. Type[out,]
1684
+ operands = Value[in, scale]
1685
+ owned_regions = Region[]
1686
+ successors = Block[]
1687
+ attributes = NamedAttribute[]
1688
+ ! isnothing (roundingmode) &&
1689
+ push! (attributes, namedattribute (" roundingmode" , roundingmode))
1690
+ ! isnothing (fastmath) && push! (attributes, namedattribute (" fastmath" , fastmath))
1691
+
1692
+ return create_operation (
1693
+ " arith.scaling_truncf" ,
1694
+ location;
1695
+ operands,
1696
+ owned_regions,
1697
+ successors,
1698
+ attributes,
1699
+ results= op_ty_results,
1700
+ result_inference= false ,
1701
+ )
1702
+ end
1703
+
1581
1704
"""
1582
1705
`shli`
1583
1706
@@ -1587,7 +1710,7 @@ unsigned. The low order bits are filled with zeros. If the value of the second
1587
1710
operand is greater or equal than the bitwidth of the first operand, then the
1588
1711
operation returns poison.
1589
1712
1590
- This op supports `nuw`/`nsw` overflow flags which stands stand for
1713
+ This op supports `nuw`/`nsw` overflow flags which stands for
1591
1714
\" No Unsigned Wrap\" and \" No Signed Wrap\" , respectively. If the `nuw` and/or
1592
1715
`nsw` flags are present, and an unsigned/signed overflow occurs
1593
1716
(respectively), the result is poison.
@@ -1775,7 +1898,7 @@ The `subi` operation takes two operands and returns one result, each of
1775
1898
these is required to be the same type. This type may be an integer scalar type,
1776
1899
a vector whose element type is integer, or a tensor of integers.
1777
1900
1778
- This op supports `nuw`/`nsw` overflow flags which stands stand for
1901
+ This op supports `nuw`/`nsw` overflow flags which stands for
1779
1902
\" No Unsigned Wrap\" and \" No Signed Wrap\" , respectively. If the `nuw` and/or
1780
1903
`nsw` flags are present, and an unsigned/signed overflow occurs
1781
1904
(respectively), the result is poison.
@@ -1865,22 +1988,35 @@ width M and an integer destination type of width N. The destination
1865
1988
bit-width must be smaller than the input bit-width (N < M).
1866
1989
The top-most (N - M) bits of the input are discarded.
1867
1990
1991
+ This op supports `nuw`/`nsw` overflow flags which stands for \" No Unsigned
1992
+ Wrap\" and \" No Signed Wrap\" , respectively. If the nuw keyword is present,
1993
+ and any of the truncated bits are non-zero, the result is a poison value.
1994
+ If the nsw keyword is present, and any of the truncated bits are not the
1995
+ same as the top bit of the truncation result, the result is a poison value.
1996
+
1868
1997
# Example
1869
1998
1870
1999
```mlir
2000
+ // Scalar truncation.
1871
2001
%1 = arith.constant 21 : i5 // %1 is 0b10101
1872
2002
%2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101
1873
2003
%3 = arith.trunci %1 : i5 to i3 // %3 is 0b101
1874
2004
1875
- %5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
2005
+ // Vector truncation.
2006
+ %4 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
2007
+
2008
+ // Scalar truncation with overflow flags.
2009
+ %5 = arith.trunci %a overflow<nsw, nuw> : i32 to i16
1876
2010
```
1877
2011
"""
1878
- function trunci (in:: Value ; out:: IR.Type , location= Location ())
2012
+ function trunci (in:: Value ; out:: IR.Type , overflowFlags = nothing , location= Location ())
1879
2013
op_ty_results = IR. Type[out,]
1880
2014
operands = Value[in,]
1881
2015
owned_regions = Region[]
1882
2016
successors = Block[]
1883
2017
attributes = NamedAttribute[]
2018
+ ! isnothing (overflowFlags) &&
2019
+ push! (attributes, namedattribute (" overflowFlags" , overflowFlags))
1884
2020
1885
2021
return create_operation (
1886
2022
" arith.trunci" ,
0 commit comments