Skip to content

Commit 840a1f4

Browse files
Mike PallBuristan
Mike Pall
authored andcommitted
Disable FMA by default. Use -Ofma or jit.opt.start("+fma") to enable.
See the discussion in the corresponding ticket for the rationale. (cherry picked from commit de2e1ca) For the modulo operation, the arm64 VM uses `fmsub` [1] instruction, which is the fused multiply-add (FMA [2]) operation (more precisely, multiply-sub). Hence, it may produce different results compared to the unfused one. This patch fixes the behaviour by using the unfused instructions by default. However, the new JIT optimization flag (fma) is introduced to make it possible to take advantage of the FMA optimizations. Sergey Kaplun: * added the description and the test for the problem [1]: https://developer.arm.com/documentation/dui0801/g/A64-Floating-point-Instructions/FMSUB [2]: https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation Part of tarantool/tarantool#10709 Reviewed-by: Sergey Bronnikov <[email protected]> Signed-off-by: Sergey Kaplun <[email protected]> (cherry picked from commit 58b013a)
1 parent 73674ed commit 840a1f4

10 files changed

+151
-6
lines changed

doc/running.html

+8
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ <h3 id="opt_O"><tt>-O[level]</tt><br>
226226
overrides all earlier flags.
227227
</p>
228228
<p>
229+
Note that <tt>-Ofma</tt> is not enabled by default at any level,
230+
because it affects floating-point result accuracy. Only enable this,
231+
if you fully understand the trade-offs of FMA for performance (higher),
232+
determinism (lower) and numerical accuracy (higher).
233+
</p>
234+
<p>
229235
Here are the available flags and at what optimization levels they
230236
are enabled:
231237
</p>
@@ -257,6 +263,8 @@ <h3 id="opt_O"><tt>-O[level]</tt><br>
257263
<td class="flag_name">sink</td><td class="flag_level">&nbsp;</td><td class="flag_level">&nbsp;</td><td class="flag_level">&bull;</td><td class="flag_desc">Allocation/Store Sinking</td></tr>
258264
<tr class="even">
259265
<td class="flag_name">fuse</td><td class="flag_level">&nbsp;</td><td class="flag_level">&nbsp;</td><td class="flag_level">&bull;</td><td class="flag_desc">Fusion of operands into instructions</td></tr>
266+
<tr class="odd">
267+
<td class="flag_name">fma </td><td class="flag_level">&nbsp;</td><td class="flag_level">&nbsp;</td><td class="flag_level">&nbsp;</td><td class="flag_desc">Fused multiply-add</td></tr>
260268
</table>
261269
<p>
262270
Here are the parameters and their default settings:

src/lj_asm_arm.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,11 @@ static void asm_fusexref(ASMState *as, ARMIns ai, Reg rd, IRRef ref,
310310
}
311311

312312
#if !LJ_SOFTFP
313-
/* Fuse to multiply-add/sub instruction. */
313+
/*
314+
** Fuse to multiply-add/sub instruction.
315+
** VMLA rounds twice (UMA, not FMA) -- no need to check for JIT_F_OPT_FMA.
316+
** VFMA needs VFPv4, which is uncommon on the remaining ARM32 targets.
317+
*/
314318
static int asm_fusemadd(ASMState *as, IRIns *ir, ARMIns ai, ARMIns air)
315319
{
316320
IRRef lref = ir->op1, rref = ir->op2;

src/lj_asm_arm64.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ static int asm_fusemadd(ASMState *as, IRIns *ir, A64Ins ai, A64Ins air)
334334
{
335335
IRRef lref = ir->op1, rref = ir->op2;
336336
IRIns *irm;
337-
if (lref != rref &&
337+
if ((as->flags & JIT_F_OPT_FMA) &&
338+
lref != rref &&
338339
((mayfuse(as, lref) && (irm = IR(lref), irm->o == IR_MUL) &&
339340
ra_noreg(irm->r)) ||
340341
(mayfuse(as, rref) && (irm = IR(rref), irm->o == IR_MUL) &&

src/lj_asm_ppc.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ static int asm_fusemadd(ASMState *as, IRIns *ir, PPCIns pi, PPCIns pir)
232232
{
233233
IRRef lref = ir->op1, rref = ir->op2;
234234
IRIns *irm;
235-
if (lref != rref &&
235+
if ((as->flags & JIT_F_OPT_FMA) &&
236+
lref != rref &&
236237
((mayfuse(as, lref) && (irm = IR(lref), irm->o == IR_MUL) &&
237238
ra_noreg(irm->r)) ||
238239
(mayfuse(as, rref) && (irm = IR(rref), irm->o == IR_MUL) &&

src/lj_jit.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@
8686
#define JIT_F_OPT_ABC (JIT_F_OPT << 7)
8787
#define JIT_F_OPT_SINK (JIT_F_OPT << 8)
8888
#define JIT_F_OPT_FUSE (JIT_F_OPT << 9)
89+
#define JIT_F_OPT_FMA (JIT_F_OPT << 10)
8990

9091
/* Optimizations names for -O. Must match the order above. */
9192
#define JIT_F_OPTSTRING \
92-
"\4fold\3cse\3dce\3fwd\3dse\6narrow\4loop\3abc\4sink\4fuse"
93+
"\4fold\3cse\3dce\3fwd\3dse\6narrow\4loop\3abc\4sink\4fuse\3fma"
9394

9495
/* Optimization levels set a fixed combination of flags. */
9596
#define JIT_F_OPT_0 0
@@ -98,6 +99,7 @@
9899
#define JIT_F_OPT_3 (JIT_F_OPT_2|\
99100
JIT_F_OPT_FWD|JIT_F_OPT_DSE|JIT_F_OPT_ABC|JIT_F_OPT_SINK|JIT_F_OPT_FUSE)
100101
#define JIT_F_OPT_DEFAULT JIT_F_OPT_3
102+
/* Note: FMA is not set by default. */
101103

102104
/* -- JIT engine parameters ----------------------------------------------- */
103105

src/lj_vmmath.c

+12-1
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,25 @@ LJ_FUNCA double lj_wrap_fmod(double x, double y) { return fmod(x, y); }
3636

3737
/* -- Helper functions ---------------------------------------------------- */
3838

39+
/* Required to prevent the C compiler from applying FMA optimizations.
40+
**
41+
** Yes, there's -ffp-contract and the FP_CONTRACT pragma ... in theory.
42+
** But the current state of C compilers is a mess in this regard.
43+
** Also, this function is not performance sensitive at all.
44+
*/
45+
LJ_NOINLINE static double lj_vm_floormul(double x, double y)
46+
{
47+
return lj_vm_floor(x / y) * y;
48+
}
49+
3950
double lj_vm_foldarith(double x, double y, int op)
4051
{
4152
switch (op) {
4253
case IR_ADD - IR_ADD: return x+y; break;
4354
case IR_SUB - IR_ADD: return x-y; break;
4455
case IR_MUL - IR_ADD: return x*y; break;
4556
case IR_DIV - IR_ADD: return x/y; break;
46-
case IR_MOD - IR_ADD: return x-lj_vm_floor(x/y)*y; break;
57+
case IR_MOD - IR_ADD: return x-lj_vm_floormul(x, y); break;
4758
case IR_POW - IR_ADD: return pow(x, y); break;
4859
case IR_NEG - IR_ADD: return -x; break;
4960
case IR_ABS - IR_ADD: return fabs(x); break;

src/vm_arm64.dasc

+3-1
Original file line numberDiff line numberDiff line change
@@ -2581,7 +2581,9 @@ static void build_ins(BuildCtx *ctx, BCOp op, int defop)
25812581
|.macro ins_arithmod, res, reg1, reg2
25822582
| fdiv d2, reg1, reg2
25832583
| frintm d2, d2
2584-
| fmsub res, d2, reg2, reg1
2584+
| // Cannot use fmsub, because FMA is not enabled by default.
2585+
| fmul d2, d2, reg2
2586+
| fsub res, reg1, d2
25852587
|.endmacro
25862588
|
25872589
|.macro ins_arithdn, intins, fpins
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
local tap = require('tap')
2+
3+
-- Test file to demonstrate consistent behaviour for JIT and the
4+
-- VM regarding FMA optimization (disabled by default).
5+
-- XXX: The VM behaviour is checked in the
6+
-- <lj-918-fma-numerical-accuracy.test.lua>.
7+
-- See also: https://github.com/LuaJIT/LuaJIT/issues/918.
8+
local test = tap.test('lj-918-fma-numerical-accuracy-jit'):skipcond({
9+
['Test requires JIT enabled'] = not jit.status(),
10+
})
11+
12+
test:plan(1)
13+
14+
local _2pow52 = 2 ^ 52
15+
16+
-- XXX: Before this commit the LuaJIT arm64 VM uses `fmsub` [1]
17+
-- instruction for the modulo operation, which is the fused
18+
-- multiply-add (FMA [2]) operation (more precisely,
19+
-- multiply-sub). Hence, it may produce different results compared
20+
-- to the unfused one. For the test, let's just use 2 numbers in
21+
-- modulo for which the single rounding is different from the
22+
-- double rounding. The numbers from the original issue are good
23+
-- enough.
24+
--
25+
-- [1]:https://developer.arm.com/documentation/dui0801/g/A64-Floating-point-Instructions/FMSUB
26+
-- [2]:https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation
27+
--
28+
-- IEEE754 components to double:
29+
-- sign * (2 ^ (exp - 1023)) * (mantissa / _2pow52 + normal).
30+
local a = 1 * (2 ^ (1083 - 1023)) * (4080546448249347 / _2pow52 + 1)
31+
assert(a == 2197541395358679800)
32+
33+
local b = -1 * (2 ^ (1052 - 1023)) * (3927497732209973 / _2pow52 + 1)
34+
assert(b == -1005065126.3690554)
35+
36+
local results = {}
37+
38+
jit.opt.start('hotloop=1')
39+
for i = 1, 4 do
40+
results[i] = a % b
41+
end
42+
43+
-- XXX: The test doesn't fail before this commit. But it is
44+
-- required to be sure that there are no inconsistencies after the
45+
-- commit.
46+
test:samevalues(results, 'consistent behaviour between the JIT and the VM')
47+
48+
test:done(true)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
local tap = require('tap')
2+
3+
-- Test file to demonstrate possible numerical inaccuracy if FMA
4+
-- optimization takes place.
5+
-- XXX: The JIT consistency is checked in the
6+
-- <lj-918-fma-numerical-accuracy-jit.test.lua>.
7+
-- See also: https://github.com/LuaJIT/LuaJIT/issues/918.
8+
local test = tap.test('lj-918-fma-numerical-accuracy')
9+
10+
test:plan(2)
11+
12+
local _2pow52 = 2 ^ 52
13+
14+
-- XXX: Before this commit the LuaJIT arm64 VM uses `fmsub` [1]
15+
-- instruction for the modulo operation, which is the fused
16+
-- multiply-add (FMA [2]) operation (more precisely,
17+
-- multiply-sub). Hence, it may produce different results compared
18+
-- to the unfused one. For the test, let's just use 2 numbers in
19+
-- modulo for which the single rounding is different from the
20+
-- double rounding. The numbers from the original issue are good
21+
-- enough.
22+
--
23+
-- [1]:https://developer.arm.com/documentation/dui0801/g/A64-Floating-point-Instructions/FMSUB
24+
-- [2]:https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation
25+
--
26+
-- IEEE754 components to double:
27+
-- sign * (2 ^ (exp - 1023)) * (mantissa / _2pow52 + normal).
28+
local a = 1 * (2 ^ (1083 - 1023)) * (4080546448249347 / _2pow52 + 1)
29+
assert(a == 2197541395358679800)
30+
31+
local b = -1 * (2 ^ (1052 - 1023)) * (3927497732209973 / _2pow52 + 1)
32+
assert(b == -1005065126.3690554)
33+
34+
-- These tests fail on ARM64 before this patch or with FMA
35+
-- optimization enabled.
36+
-- The first test may not fail if the compiler doesn't generate
37+
-- an ARM64 FMA operation in `lj_vm_foldarith()`.
38+
test:is(2197541395358679800 % -1005065126.3690554, -606337536,
39+
'FMA in the lj_vm_foldarith() during parsing')
40+
41+
test:is(a % b, -606337536, 'FMA in the VM')
42+
43+
test:done(true)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
local tap = require('tap')
2+
local test = tap.test('lj-918-fma-optimization'):skipcond({
3+
['Test requires JIT enabled'] = not jit.status(),
4+
})
5+
6+
test:plan(3)
7+
8+
local function jit_opt_is_on(flag)
9+
for _, opt in ipairs({jit.status()}) do
10+
if opt == flag then
11+
return true
12+
end
13+
end
14+
return false
15+
end
16+
17+
test:ok(not jit_opt_is_on('fma'), 'FMA is disabled by default')
18+
19+
local ok, _ = pcall(jit.opt.start, '+fma')
20+
21+
test:ok(ok, 'fma flag is recognized')
22+
23+
test:ok(jit_opt_is_on('fma'), 'FMA is enabled after jit.opt.start()')
24+
25+
test:done(true)

0 commit comments

Comments
 (0)