Skip to content

Commit 37ea241

Browse files
Regenerate MLIR Bindings (#1419)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 667610b commit 37ea241

File tree

10 files changed

+428
-37
lines changed

10 files changed

+428
-37
lines changed

src/mlir/Dialects/Arith.jl

Lines changed: 142 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ The `addi` operation takes two operands and returns one result, each of
7777
these is required to be the same type. This type may be an integer scalar type,
7878
a vector whose element type is integer, or a tensor of integers.
7979
80-
This op supports `nuw`/`nsw` overflow flags which stands stand for
80+
This op supports `nuw`/`nsw` overflow flags which stands for
8181
\"No Unsigned Wrap\" and \"No Signed Wrap\", respectively. If the `nuw` and/or
8282
`nsw` flags are present, and an unsigned/signed overflow occurs
8383
(respectively), the result is poison.
@@ -1193,7 +1193,7 @@ The `muli` operation takes two operands and returns one result, each of
11931193
these is required to be the same type. This type may be an integer scalar type,
11941194
a vector whose element type is integer, or a tensor of integers.
11951195
1196-
This op supports `nuw`/`nsw` overflow flags which stands stand for
1196+
This op supports `nuw`/`nsw` overflow flags which stands for
11971197
\"No Unsigned Wrap\" and \"No Signed Wrap\", respectively. If the `nuw` and/or
11981198
`nsw` flags are present, and an unsigned/signed overflow occurs
11991199
(respectively), the result is poison.
@@ -1578,6 +1578,129 @@ function sitofp(in::Value; out::IR.Type, location=Location())
15781578
)
15791579
end
15801580

1581+
"""
1582+
`scaling_extf`
1583+
1584+
This operation upcasts input floating-point values using provided scale
1585+
values. It expects both scales and the input operand to be of the same shape,
1586+
making the operation elementwise. Scales are usually calculated per block
1587+
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
1588+
1589+
If scales are calculated per block where blockSize != 1, then scales may
1590+
require broadcasting to make this operation elementwise. For example, let\'s
1591+
say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1592+
assuming quantization happens on the last axis, the input can be reshaped to
1593+
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1594+
per block on the last axis. Therefore, scales will be of shape
1595+
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1596+
shape as long as it is broadcast compatible with the input, e.g.,
1597+
`<1 x 1 x ... (dimN/blockSize) x 1>`.
1598+
1599+
In this example, before calling into `arith.scaling_extf`, scales must be
1600+
broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1601+
that there could be multiple quantization axes. Internally,
1602+
`arith.scaling_extf` would perform the following:
1603+
1604+
```
1605+
resultTy = get_type(result)
1606+
scaleTy = get_type(scale)
1607+
inputTy = get_type(input)
1608+
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
1609+
scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
1610+
input.extf = arith.extf(input) : inputTy to resultTy
1611+
result = arith.mulf(scale.extf, input.extf)
1612+
```
1613+
It propagates NaN values. Therefore, if either scale or the input element
1614+
contains NaN, then the output element value will also be a NaN.
1615+
"""
1616+
function scaling_extf(
1617+
in::Value, scale::Value; out::IR.Type, fastmath=nothing, location=Location()
1618+
)
1619+
op_ty_results = IR.Type[out,]
1620+
operands = Value[in, scale]
1621+
owned_regions = Region[]
1622+
successors = Block[]
1623+
attributes = NamedAttribute[]
1624+
!isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath))
1625+
1626+
return create_operation(
1627+
"arith.scaling_extf",
1628+
location;
1629+
operands,
1630+
owned_regions,
1631+
successors,
1632+
attributes,
1633+
results=op_ty_results,
1634+
result_inference=false,
1635+
)
1636+
end
1637+
1638+
"""
1639+
`scaling_truncf`
1640+
1641+
This operation downcasts input using the provided scale values. It expects
1642+
both scales and the input operand to be of the same shape and, therefore,
1643+
makes the operation elementwise. Scales are usually calculated per block
1644+
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
1645+
Users are required to normalize and clamp the scales as necessary before calling
1646+
passing them to this operation. OCP MXFP spec also does the flushing of denorms
1647+
on the input operand, which should be handled during lowering by passing appropriate
1648+
fastMath flag to this operation.
1649+
1650+
If scales are calculated per block where blockSize != 1, scales may require
1651+
broadcasting to make this operation elementwise. For example, let\'s say the
1652+
input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
1653+
assuming quantization happens on the last axis, the input can be reshaped to
1654+
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
1655+
per block on the last axis. Therefore, scales will be of shape
1656+
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
1657+
shape as long as it is broadcast compatible with the input, e.g.,
1658+
`<1 x 1 x ... (dimN/blockSize) x 1>`.
1659+
1660+
In this example, before calling into `arith.scaling_truncf`, scales must be
1661+
broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
1662+
that there could be multiple quantization axes. Internally,
1663+
`arith.scaling_truncf` would perform the following:
1664+
1665+
```
1666+
scaleTy = get_type(scale)
1667+
inputTy = get_type(input)
1668+
resultTy = get_type(result)
1669+
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
1670+
scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
1671+
result = arith.divf(input, scale.extf)
1672+
result.cast = arith.truncf(result, resultTy)
1673+
```
1674+
"""
1675+
function scaling_truncf(
1676+
in::Value,
1677+
scale::Value;
1678+
out::IR.Type,
1679+
roundingmode=nothing,
1680+
fastmath=nothing,
1681+
location=Location(),
1682+
)
1683+
op_ty_results = IR.Type[out,]
1684+
operands = Value[in, scale]
1685+
owned_regions = Region[]
1686+
successors = Block[]
1687+
attributes = NamedAttribute[]
1688+
!isnothing(roundingmode) &&
1689+
push!(attributes, namedattribute("roundingmode", roundingmode))
1690+
!isnothing(fastmath) && push!(attributes, namedattribute("fastmath", fastmath))
1691+
1692+
return create_operation(
1693+
"arith.scaling_truncf",
1694+
location;
1695+
operands,
1696+
owned_regions,
1697+
successors,
1698+
attributes,
1699+
results=op_ty_results,
1700+
result_inference=false,
1701+
)
1702+
end
1703+
15811704
"""
15821705
`shli`
15831706
@@ -1587,7 +1710,7 @@ unsigned. The low order bits are filled with zeros. If the value of the second
15871710
operand is greater or equal than the bitwidth of the first operand, then the
15881711
operation returns poison.
15891712
1590-
This op supports `nuw`/`nsw` overflow flags which stands stand for
1713+
This op supports `nuw`/`nsw` overflow flags which stands for
15911714
\"No Unsigned Wrap\" and \"No Signed Wrap\", respectively. If the `nuw` and/or
15921715
`nsw` flags are present, and an unsigned/signed overflow occurs
15931716
(respectively), the result is poison.
@@ -1775,7 +1898,7 @@ The `subi` operation takes two operands and returns one result, each of
17751898
these is required to be the same type. This type may be an integer scalar type,
17761899
a vector whose element type is integer, or a tensor of integers.
17771900
1778-
This op supports `nuw`/`nsw` overflow flags which stands stand for
1901+
This op supports `nuw`/`nsw` overflow flags which stands for
17791902
\"No Unsigned Wrap\" and \"No Signed Wrap\", respectively. If the `nuw` and/or
17801903
`nsw` flags are present, and an unsigned/signed overflow occurs
17811904
(respectively), the result is poison.
@@ -1865,22 +1988,35 @@ width M and an integer destination type of width N. The destination
18651988
bit-width must be smaller than the input bit-width (N < M).
18661989
The top-most (N - M) bits of the input are discarded.
18671990
1991+
This op supports `nuw`/`nsw` overflow flags which stands for \"No Unsigned
1992+
Wrap\" and \"No Signed Wrap\", respectively. If the nuw keyword is present,
1993+
and any of the truncated bits are non-zero, the result is a poison value.
1994+
If the nsw keyword is present, and any of the truncated bits are not the
1995+
same as the top bit of the truncation result, the result is a poison value.
1996+
18681997
# Example
18691998
18701999
```mlir
2000+
// Scalar truncation.
18712001
%1 = arith.constant 21 : i5 // %1 is 0b10101
18722002
%2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101
18732003
%3 = arith.trunci %1 : i5 to i3 // %3 is 0b101
18742004
1875-
%5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
2005+
// Vector truncation.
2006+
%4 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
2007+
2008+
// Scalar truncation with overflow flags.
2009+
%5 = arith.trunci %a overflow<nsw, nuw> : i32 to i16
18762010
```
18772011
"""
1878-
function trunci(in::Value; out::IR.Type, location=Location())
2012+
function trunci(in::Value; out::IR.Type, overflowFlags=nothing, location=Location())
18792013
op_ty_results = IR.Type[out,]
18802014
operands = Value[in,]
18812015
owned_regions = Region[]
18822016
successors = Block[]
18832017
attributes = NamedAttribute[]
2018+
!isnothing(overflowFlags) &&
2019+
push!(attributes, namedattribute("overflowFlags", overflowFlags))
18842020

18852021
return create_operation(
18862022
"arith.trunci",

src/mlir/Dialects/Builtin.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ A `module` represents a top-level container operation. It contains a single
2020
[graph region](../LangRef.md#control-flow-and-ssacfg-regions) containing a single block
2121
which can contain any operations and does not have a terminator. Operations
2222
within this region cannot implicitly capture values defined outside the module,
23-
i.e. Modules are [IsolatedFromAbove](../Traits.md#isolatedfromabove). Modules have
23+
i.e. Modules are [IsolatedFromAbove](../Traits#isolatedfromabove). Modules have
2424
an optional [symbol name](../SymbolsAndSymbolTables.md) which can be used to refer
2525
to them in operations.
2626

src/mlir/Dialects/Gpu.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ end
130130
"""
131131
`barrier`
132132
133-
The \"barrier\" op synchronizes all work items of a workgroup. It is used
133+
The `barrier` op synchronizes all work items of a workgroup. It is used
134134
to coordinate communication between the work items of the workgroup.
135135
136136
```mlir
@@ -322,7 +322,7 @@ Returns the block id within the cluster along the x, y, or z `dimension`.
322322
```
323323
324324
If `upper_bound` is set, then executing (a lowering of) this operation in an
325-
environment where the number of thread blocks per cluster along `dimension`
325+
environment where the number of thread blocks per cluster along `dimension`
326326
is greater than `upper_bound` causes undefined behavior.
327327
328328
There is an implicit upper bound of `kMaxClusterDim` (currently 8).
@@ -1905,7 +1905,7 @@ end
19051905
"""
19061906
`return_`
19071907
1908-
A terminator operation for regions that appear in the body of `gpu.func`
1908+
A terminator operation for regions that appear in the body of `gpu.func`
19091909
functions. The operands to the `gpu.return` are the result values returned
19101910
by an invocation of the `gpu.func`.
19111911
"""
@@ -2141,7 +2141,8 @@ trades value with exactly one other lane.
21412141
%3, %4 = gpu.shuffle down %0, %cst1, %width : f32
21422142
```
21432143
2144-
For lane `k`, returns the value from lane `(k + 1) % width`.
2144+
For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
2145+
bigger than or equal to `width`, the value is poison and `valid` is `false`.
21452146
21462147
`up` example:
21472148
@@ -2150,7 +2151,8 @@ For lane `k`, returns the value from lane `(k + 1) % width`.
21502151
%5, %6 = gpu.shuffle up %0, %cst1, %width : f32
21512152
```
21522153
2153-
For lane `k`, returns the value from lane `(k - 1) % width`.
2154+
For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
2155+
smaller than `0`, the value is poison and `valid` is `false`.
21542156
21552157
`idx` example:
21562158
@@ -3412,7 +3414,7 @@ end
34123414
"""
34133415
`yield`
34143416
3415-
gpu.yield` is a special terminator operation for blocks inside regions
3417+
`gpu.yield` is a special terminator operation for blocks inside regions
34163418
in gpu ops. It returns values to the immediately enclosing gpu op.
34173419
34183420
# Example

src/mlir/Dialects/Llvm.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,6 @@ function call(
580580
var_callee_type=nothing,
581581
callee=nothing,
582582
fastmathFlags=nothing,
583-
branch_weights=nothing,
584583
CConv=nothing,
585584
TailCallKind=nothing,
586585
memory_effects=nothing,
@@ -615,8 +614,6 @@ function call(
615614
!isnothing(callee) && push!(attributes, namedattribute("callee", callee))
616615
!isnothing(fastmathFlags) &&
617616
push!(attributes, namedattribute("fastmathFlags", fastmathFlags))
618-
!isnothing(branch_weights) &&
619-
push!(attributes, namedattribute("branch_weights", branch_weights))
620617
!isnothing(CConv) && push!(attributes, namedattribute("CConv", CConv))
621618
!isnothing(TailCallKind) &&
622619
push!(attributes, namedattribute("TailCallKind", TailCallKind))
@@ -1854,6 +1851,8 @@ function func(;
18541851
frame_pointer=nothing,
18551852
target_cpu=nothing,
18561853
tune_cpu=nothing,
1854+
reciprocal_estimates=nothing,
1855+
prefer_vector_width=nothing,
18571856
target_features=nothing,
18581857
unsafe_fp_math=nothing,
18591858
no_infs_fp_math=nothing,
@@ -1927,6 +1926,10 @@ function func(;
19271926
push!(attributes, namedattribute("frame_pointer", frame_pointer))
19281927
!isnothing(target_cpu) && push!(attributes, namedattribute("target_cpu", target_cpu))
19291928
!isnothing(tune_cpu) && push!(attributes, namedattribute("tune_cpu", tune_cpu))
1929+
!isnothing(reciprocal_estimates) &&
1930+
push!(attributes, namedattribute("reciprocal_estimates", reciprocal_estimates))
1931+
!isnothing(prefer_vector_width) &&
1932+
push!(attributes, namedattribute("prefer_vector_width", prefer_vector_width))
19301933
!isnothing(target_features) &&
19311934
push!(attributes, namedattribute("target_features", target_features))
19321935
!isnothing(unsafe_fp_math) &&

src/mlir/Dialects/MemRef.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ end
327327
The `alloca` operation allocates memory on the stack, to be automatically
328328
released when control transfers back from the region of its closest
329329
surrounding operation with an
330-
[`AutomaticAllocationScope`](../Traits.md/#automaticallocationscope) trait.
330+
[`AutomaticAllocationScope`](../Traits/#automaticallocationscope) trait.
331331
The amount of memory allocated is specified by its memref and additional
332332
operands. For example:
333333

0 commit comments

Comments
 (0)