Skip to content

Commit 25ebb5d

Browse files
committed
[autodiff] When asserts are enabled, verify all autodiff compiler generated functions.
This ensures that any invalid SIL generated by these cloners is caught immediately at the source when asserts are enabled improving productivity.
1 parent 887464b commit 25ebb5d

File tree

4 files changed

+38
-5
lines changed

4 files changed

+38
-5
lines changed

include/swift/SILOptimizer/Differentiation/JVPCloner.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class JVPCloner final {
4747
/// Performs JVP generation on the empty JVP function. Returns true if any
4848
/// error occurs.
4949
bool run();
50+
51+
SILFunction &getJVP() const;
5052
};
5153

5254
} // end namespace autodiff

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
3232
#include "llvm/ADT/DenseMap.h"
3333

34+
using namespace swift;
35+
using namespace autodiff;
36+
3437
namespace swift {
3538
namespace autodiff {
3639

@@ -380,6 +383,8 @@ class JVPCloner::Implementation final
380383
/// Run JVP generation. Returns true on error.
381384
bool run();
382385

386+
SILFunction &getJVP() const { return *jvp; }
387+
383388
void postProcess(SILInstruction *orig, SILInstruction *cloned) {
384389
if (errorOccurred)
385390
return;
@@ -1727,7 +1732,16 @@ bool JVPCloner::Implementation::run() {
17271732
return errorOccurred;
17281733
}
17291734

1730-
bool JVPCloner::run() { return impl.run(); }
1731-
17321735
} // end namespace autodiff
17331736
} // end namespace swift
1737+
1738+
bool JVPCloner::run() {
1739+
bool foundError = impl.run();
1740+
#ifndef NDEBUG
1741+
if (!foundError)
1742+
getJVP().verify();
1743+
#endif
1744+
return foundError;
1745+
}
1746+
1747+
SILFunction &JVPCloner::getJVP() const { return impl.getJVP(); }

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ class PullbackCloner::Implementation final
138138
SILModule &getModule() const { return getContext().getModule(); }
139139
ASTContext &getASTContext() const { return getPullback().getASTContext(); }
140140
SILFunction &getOriginal() const { return vjpCloner.getOriginal(); }
141-
SILFunction &getPullback() const { return vjpCloner.getPullback(); }
142141
SILDifferentiabilityWitness *getWitness() const {
143142
return vjpCloner.getWitness();
144143
}
@@ -782,6 +781,10 @@ class PullbackCloner::Implementation final
782781
/// parameters.
783782
void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult);
784783

784+
/// Public helper so that our users can get the underlying newly created
785+
/// function.
786+
SILFunction &getPullback() const { return vjpCloner.getPullback(); }
787+
785788
using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>;
786789

787790
/// Determines the pullback successor block for a given original block and one
@@ -1740,7 +1743,14 @@ PullbackCloner::~PullbackCloner() { delete &impl; }
17401743
// Entry point
17411744
//--------------------------------------------------------------------------//
17421745

1743-
bool PullbackCloner::run() { return impl.run(); }
1746+
bool PullbackCloner::run() {
1747+
bool foundError = impl.run();
1748+
#ifndef NDEBUG
1749+
if (!foundError)
1750+
impl.getPullback().verify();
1751+
#endif
1752+
return foundError;
1753+
}
17441754

17451755
bool PullbackCloner::Implementation::run() {
17461756
PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal());

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,14 @@ bool VJPCloner::Implementation::run() {
10191019
return errorOccurred;
10201020
}
10211021

1022-
bool VJPCloner::run() { return impl.run(); }
1022+
bool VJPCloner::run() {
1023+
bool foundError = impl.run();
1024+
#ifndef NDEBUG
1025+
if (!foundError)
1026+
getVJP().verify();
1027+
#endif
1028+
return foundError;
1029+
}
10231030

10241031
} // end namespace autodiff
10251032
} // end namespace swift

0 commit comments

Comments
 (0)