Skip to content

Commit 933247d

Browse files
authored
[SimplifyLibCalls] Merge sqrt into the power of exp (#79146)
Under fast-math flags it is possible to convert `sqrt(exp(X))` into `exp(X * 0.5)`. I suppose that this transformation is always profitable. It is similar to an optimization that already exists in GCC.
1 parent 0b62218 commit 933247d

File tree

3 files changed

+188
-0
lines changed

3 files changed

+188
-0
lines changed

llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h

+1
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ class LibCallSimplifier {
201201
Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B);
202202
Value *optimizeLog(CallInst *CI, IRBuilderBase &B);
203203
Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B);
204+
Value *mergeSqrtToExp(CallInst *CI, IRBuilderBase &B);
204205
Value *optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B);
205206
Value *optimizeTan(CallInst *CI, IRBuilderBase &B);
206207
// Wrapper for all floating point library call optimizations

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

+67
Original file line numberDiff line numberDiff line change
@@ -2545,6 +2545,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
25452545
return Ret;
25462546
}
25472547

2548+
// sqrt(exp(X)) -> exp(X * 0.5)
2549+
Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
2550+
if (!CI->hasAllowReassoc())
2551+
return nullptr;
2552+
2553+
Function *SqrtFn = CI->getCalledFunction();
2554+
CallInst *Arg = dyn_cast<CallInst>(CI->getArgOperand(0));
2555+
if (!Arg || !Arg->hasAllowReassoc() || !Arg->hasOneUse())
2556+
return nullptr;
2557+
Intrinsic::ID ArgID = Arg->getIntrinsicID();
2558+
LibFunc ArgLb = NotLibFunc;
2559+
TLI->getLibFunc(*Arg, ArgLb);
2560+
2561+
LibFunc SqrtLb, ExpLb, Exp2Lb, Exp10Lb;
2562+
2563+
if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb))
2564+
switch (SqrtLb) {
2565+
case LibFunc_sqrtf:
2566+
ExpLb = LibFunc_expf;
2567+
Exp2Lb = LibFunc_exp2f;
2568+
Exp10Lb = LibFunc_exp10f;
2569+
break;
2570+
case LibFunc_sqrt:
2571+
ExpLb = LibFunc_exp;
2572+
Exp2Lb = LibFunc_exp2;
2573+
Exp10Lb = LibFunc_exp10;
2574+
break;
2575+
case LibFunc_sqrtl:
2576+
ExpLb = LibFunc_expl;
2577+
Exp2Lb = LibFunc_exp2l;
2578+
Exp10Lb = LibFunc_exp10l;
2579+
break;
2580+
default:
2581+
return nullptr;
2582+
}
2583+
else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) {
2584+
if (CI->getType()->getScalarType()->isFloatTy()) {
2585+
ExpLb = LibFunc_expf;
2586+
Exp2Lb = LibFunc_exp2f;
2587+
Exp10Lb = LibFunc_exp10f;
2588+
} else if (CI->getType()->getScalarType()->isDoubleTy()) {
2589+
ExpLb = LibFunc_exp;
2590+
Exp2Lb = LibFunc_exp2;
2591+
Exp10Lb = LibFunc_exp10;
2592+
} else
2593+
return nullptr;
2594+
} else
2595+
return nullptr;
2596+
2597+
if (ArgLb != ExpLb && ArgLb != Exp2Lb && ArgLb != Exp10Lb &&
2598+
ArgID != Intrinsic::exp && ArgID != Intrinsic::exp2)
2599+
return nullptr;
2600+
2601+
IRBuilderBase::InsertPointGuard Guard(B);
2602+
B.SetInsertPoint(Arg);
2603+
auto *ExpOperand = Arg->getOperand(0);
2604+
auto *FMul =
2605+
B.CreateFMulFMF(ExpOperand, ConstantFP::get(ExpOperand->getType(), 0.5),
2606+
CI, "merged.sqrt");
2607+
2608+
Arg->setOperand(0, FMul);
2609+
return Arg;
2610+
}
2611+
25482612
Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
25492613
Module *M = CI->getModule();
25502614
Function *Callee = CI->getCalledFunction();
@@ -2557,6 +2621,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
25572621
Callee->getIntrinsicID() == Intrinsic::sqrt))
25582622
Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
25592623

2624+
if (Value *Opt = mergeSqrtToExp(CI, B))
2625+
return Opt;
2626+
25602627
if (!CI->isFast())
25612628
return Ret;
25622629

llvm/test/Transforms/InstCombine/sqrt.ll

+120
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,127 @@ define float @sqrt_call_fabs_f32(float %x) {
8888
ret float %sqrt
8989
}
9090

91+
; llvm.sqrt applied to llvm.exp, both calls marked reassoc: the CHECK lines
; expect a single exp call on x * 0.5 and no sqrt.
define double @sqrt_exp(double %x) {
92+
; CHECK-LABEL: @sqrt_exp(
93+
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
94+
; CHECK-NEXT: [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
95+
; CHECK-NEXT: ret double [[E]]
96+
;
97+
%e = call reassoc double @llvm.exp.f64(double %x)
98+
%res = call reassoc double @llvm.sqrt.f64(double %e)
99+
ret double %res
100+
}
101+
102+
; Library-call variant: @sqrt applied to @exp, both calls marked reassoc. The
; CHECK lines expect a single @exp call on x * 0.5.
define double @sqrt_exp_2(double %x) {
103+
; CHECK-LABEL: @sqrt_exp_2(
104+
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
105+
; CHECK-NEXT: [[E:%.*]] = call reassoc double @exp(double [[MERGED_SQRT]])
106+
; CHECK-NEXT: ret double [[E]]
107+
;
108+
%e = call reassoc double @exp(double %x)
109+
%res = call reassoc double @sqrt(double %e)
110+
ret double %res
111+
}
112+
113+
; @sqrt applied to @exp2, both calls marked reassoc. The CHECK lines expect a
; single @exp2 call on x * 0.5.
define double @sqrt_exp2(double %x) {
114+
; CHECK-LABEL: @sqrt_exp2(
115+
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
116+
; CHECK-NEXT: [[E:%.*]] = call reassoc double @exp2(double [[MERGED_SQRT]])
117+
; CHECK-NEXT: ret double [[E]]
118+
;
119+
%e = call reassoc double @exp2(double %x)
120+
%res = call reassoc double @sqrt(double %e)
121+
ret double %res
122+
}
123+
124+
; @sqrt applied to @exp10, both calls marked reassoc. The CHECK lines expect a
; single @exp10 call on x * 0.5.
define double @sqrt_exp10(double %x) {
125+
; CHECK-LABEL: @sqrt_exp10(
126+
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
127+
; CHECK-NEXT: [[E:%.*]] = call reassoc double @exp10(double [[MERGED_SQRT]])
128+
; CHECK-NEXT: ret double [[E]]
129+
;
130+
%e = call reassoc double @exp10(double %x)
131+
%res = call reassoc double @sqrt(double %e)
132+
ret double %res
133+
}
134+
135+
; Negative test: the exp call has no reassoc flag (only the sqrt does), so the
; CHECK lines expect both calls to remain unchanged.
136+
define double @sqrt_exp_nofast_1(double %x) {
137+
; CHECK-LABEL: @sqrt_exp_nofast_1(
138+
; CHECK-NEXT: [[E:%.*]] = call double @llvm.exp.f64(double [[X:%.*]])
139+
; CHECK-NEXT: [[RES:%.*]] = call reassoc double @llvm.sqrt.f64(double [[E]])
140+
; CHECK-NEXT: ret double [[RES]]
141+
;
142+
%e = call double @llvm.exp.f64(double %x)
143+
%res = call reassoc double @llvm.sqrt.f64(double %e)
144+
ret double %res
145+
}
146+
147+
; Negative test: the sqrt call has no reassoc flag (only the exp does), so the
; CHECK lines expect both calls to remain unchanged.
148+
define double @sqrt_exp_nofast_2(double %x) {
149+
; CHECK-LABEL: @sqrt_exp_nofast_2(
150+
; CHECK-NEXT: [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[X:%.*]])
151+
; CHECK-NEXT: [[RES:%.*]] = call double @llvm.sqrt.f64(double [[E]])
152+
; CHECK-NEXT: ret double [[RES]]
153+
;
154+
%e = call reassoc double @llvm.exp.f64(double %x)
155+
%res = call double @llvm.sqrt.f64(double %e)
156+
ret double %res
157+
}
158+
159+
; The exp operand is already x * 10.0 (reassoc nsz) and the sqrt carries
; reassoc nsz: the CHECK lines expect a single multiply by 5.0 feeding exp,
; i.e. the 0.5 factor combined with the existing constant.
define double @sqrt_exp_merge_constant(double %x, double %y) {
160+
; CHECK-LABEL: @sqrt_exp_merge_constant(
161+
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc nsz double [[X:%.*]], 5.000000e+00
162+
; CHECK-NEXT: [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
163+
; CHECK-NEXT: ret double [[E]]
164+
;
165+
%mul = fmul reassoc nsz double %x, 10.0
166+
%e = call reassoc double @llvm.exp.f64(double %mul)
167+
%res = call reassoc nsz double @llvm.sqrt.f64(double %e)
168+
ret double %res
169+
}
170+
171+
; Mixed forms: llvm.sqrt intrinsic applied to the @exp library call, both
; marked reassoc. The CHECK lines expect a single @exp call on x * 0.5.
define double @sqrt_exp_intr_and_libcall(double %x) {
172+
; CHECK-LABEL: @sqrt_exp_intr_and_libcall(
173+
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
174+
; CHECK-NEXT: [[E:%.*]] = call reassoc double @exp(double [[MERGED_SQRT]])
175+
; CHECK-NEXT: ret double [[E]]
176+
;
177+
%e = call reassoc double @exp(double %x)
178+
%res = call reassoc double @llvm.sqrt.f64(double %e)
179+
ret double %res
180+
}
181+
182+
; Mixed forms, the other way round: @sqrt library call applied to the llvm.exp
; intrinsic, both marked reassoc. The CHECK lines expect a single llvm.exp call
; on x * 0.5.
define double @sqrt_exp_intr_and_libcall_2(double %x) {
183+
; CHECK-LABEL: @sqrt_exp_intr_and_libcall_2(
184+
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc double [[X:%.*]], 5.000000e-01
185+
; CHECK-NEXT: [[E:%.*]] = call reassoc double @llvm.exp.f64(double [[MERGED_SQRT]])
186+
; CHECK-NEXT: ret double [[E]]
187+
;
188+
%e = call reassoc double @llvm.exp.f64(double %x)
189+
%res = call reassoc double @sqrt(double %e)
190+
ret double %res
191+
}
192+
193+
; Vector variant: llvm.sqrt.v2f32 applied to llvm.exp.v2f32, both marked
; reassoc. The CHECK lines expect the multiply to use a <0.5, 0.5> constant
; vector and a single vector exp call.
define <2 x float> @sqrt_exp_vec(<2 x float> %x) {
194+
; CHECK-LABEL: @sqrt_exp_vec(
195+
; CHECK-NEXT: [[MERGED_SQRT:%.*]] = fmul reassoc <2 x float> [[X:%.*]], <float 5.000000e-01, float 5.000000e-01>
196+
; CHECK-NEXT: [[E:%.*]] = call reassoc <2 x float> @llvm.exp.v2f32(<2 x float> [[MERGED_SQRT]])
197+
; CHECK-NEXT: ret <2 x float> [[E]]
198+
;
199+
%e = call reassoc <2 x float> @llvm.exp.v2f32(<2 x float> %x)
200+
%res = call reassoc <2 x float> @llvm.sqrt.v2f32(<2 x float> %e)
201+
ret <2 x float> %res
202+
}
203+
91204
declare i32 @foo(double)
92205
declare double @sqrt(double) readnone
93206
declare float @sqrtf(float)
94207
declare float @llvm.fabs.f32(float)
208+
declare double @llvm.exp.f64(double)
209+
declare double @llvm.sqrt.f64(double)
210+
declare double @exp(double)
211+
declare double @exp2(double)
212+
declare double @exp10(double)
213+
declare <2 x float> @llvm.exp.v2f32(<2 x float>)
214+
declare <2 x float> @llvm.sqrt.v2f32(<2 x float>)

0 commit comments

Comments
 (0)