Skip to content

Commit 845c7d5

Browse files
author
meister
committed
Added another example passing parameters to function
1 parent 402dc3e commit 845c7d5

File tree

1 file changed

+89
-3
lines changed

1 file changed

+89
-3
lines changed

tlm3_dihedral/Dihed.cc

+89-3
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ struct Dihedral {
7676
extern double __enzyme_autodiff(void*, ...);
7777

7878
void dump_vec(const std::string& msg, float* deriv, size_t num) {
79-
std::cout << std::setw(10) << msg << " ";
79+
std::cout << std::setw(20) << msg << " ";
8080
for ( int i=0; i<num; i++ ) {
8181
std::cout << std::setprecision(4) << std::fixed << std::setw(8) << deriv[i] << " ";
8282
}
@@ -87,6 +87,74 @@ void dump_vec(const std::string& msg, float* deriv, size_t num) {
8787
#define mysqrt(VAR) (std::sqrt(VAR))
8888
#define reciprocal(VAR) (1.0/VAR)
8989

90+
double old_simple_dihedral_energy_gradient( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float* pos, float* deriv ) {
91+
#define USE_EXPLICIT_DECLARES 1
92+
#define DECLARE_FLOAT(VAR) double VAR;
93+
#define Real double
94+
double EraseLinearDihedral;
95+
double x1, y1, z1;
96+
double x2, y2, z2;
97+
double x3, y3, z3;
98+
double x4, y4, z4;
99+
#include "_Dihedral_termDeclares.cc"
100+
#define DIHEDRAL_SET_PARAMETER(VAR) {}
101+
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
102+
#define DIHEDRAL_ENERGY_ACCUMULATE(VAR) { result += VAR; }
103+
#define DIHEDRAL_FORCE_ACCUMULATE(IDX,OFFSET,VAR) { deriv[IDX+OFFSET] -= VAR; }
104+
double result = 0.0;
105+
double SinNPhi;
106+
double CosNPhi;
107+
bool calcForce = true;
108+
#define DIHEDRAL_CALC_FORCE 1
109+
#include "_Dihedral_termCode.cc"
110+
#undef Real
111+
#undef DECLARE_FLOAT
112+
#undef DIHEDRAL_SET_PARAMETER
113+
#undef DIHEDRAL_SET_POSITION
114+
#undef DIHEDRAL_ENERGY_ACCUMULATE
115+
#undef DIHEDRAL_FORCE_ACCUMULATE
116+
return result;
117+
}
118+
119+
double simple_dihedral_energy( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float* pos ) {
120+
#define USE_EXPLICIT_DECLARES 1
121+
#define DECLARE_FLOAT(VAR) double VAR;
122+
#define Real double
123+
double EraseLinearDihedral;
124+
double x1, y1, z1;
125+
double x2, y2, z2;
126+
double x3, y3, z3;
127+
double x4, y4, z4;
128+
#include "_DihedralEnergy_termDeclares.cc"
129+
#define DIHEDRAL_SET_PARAMETER(VAR) {}
130+
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
131+
#define DIHEDRAL_ENERGY_ACCUMULATE(VAR) { result += VAR; }
132+
double result = 0.0;
133+
double SinNPhi;
134+
double CosNPhi;
135+
#include "_DihedralEnergy_termCode.cc"
136+
#undef Real
137+
#undef DECLARE_FLOAT
138+
#undef DIHEDRAL_SET_PARAMETER
139+
#undef DIHEDRAL_SET_POSITION
140+
#undef DIHEDRAL_ENERGY_ACCUMULATE
141+
return result;
142+
}
143+
144+
void simple_dihedral_gradient( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float* pos, float* grad ) {
145+
__enzyme_autodiff( (void*)simple_dihedral_energy,
146+
enzyme_const, sinPhase,
147+
enzyme_const, cosPhase,
148+
enzyme_const, V,
149+
enzyme_const, DN,
150+
enzyme_const, IN,
151+
enzyme_const, I1,
152+
enzyme_const, I2,
153+
enzyme_const, I3,
154+
enzyme_const, I4,
155+
enzyme_dup, pos, grad );
156+
}
157+
90158
template <typename Real>
91159
float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos) {
92160
#define USE_EXPLICIT_DECLARES 1
@@ -123,7 +191,7 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
123191
}
124192

125193
template <typename Real>
126-
void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv ) {
194+
__attribute__((noinline)) void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv ) {
127195
__enzyme_autodiff( (void*)dihedral_energy<Real>,
128196
enzyme_const, dihedral_begin,
129197
enzyme_const, dihedral_end,
@@ -174,7 +242,7 @@ Real old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float*
174242
#define DIHEDRAL_SET_PARAMETER(VAR) { VAR = dihedral->VAR; }
175243
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
176244
#define DIHEDRAL_ENERGY_ACCUMULATE(VAR) { result += VAR; }
177-
#define DIHEDRAL_FORCE_ACCUMULATE(IDX,OFFSET,VAR) { deriv[IDX+OFFSET] += VAR; }
245+
#define DIHEDRAL_FORCE_ACCUMULATE(IDX,OFFSET,VAR) { deriv[IDX+OFFSET] -= VAR; }
178246
int I1;
179247
int I2;
180248
int I3;
@@ -228,18 +296,21 @@ int main( int argc, const char* argv[] ) {
228296
zeroVec(deriv,12);
229297
grad_dihedral_energy<float>( &dihedral[0], &dihedral[1], pos, deriv);
230298
}
299+
dump_vec("new float deriv", deriv, 12);
231300
} else if (arg1 == "new-double") {
232301
std::cout << "New method" << "\n";
233302
// energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
234303
for ( size_t nn = 0; nn<num; nn++ ) {
235304
zeroVec(deriv,12);
236305
grad_dihedral_energy<double>( &dihedral[0], &dihedral[1], pos, deriv);
237306
}
307+
dump_vec("new double deriv", deriv, 12);
238308
} else if ( arg1 == "fin") {
239309
for ( size_t nn = 0; nn<num; nn++ ) {
240310
zeroVec(deriv,12);
241311
finite_diff_dihedral_energy(&dihedral[0], &dihedral[1], pos, deriv);
242312
}
313+
dump_vec("fin deriv", deriv, 12);
243314
} else if (arg1 == "old-float") {
244315
float fenergy;
245316
float fderiv[12];
@@ -264,5 +335,20 @@ int main( int argc, const char* argv[] ) {
264335
}
265336
std::cout << "Energy = " << fenergy << "\n";
266337
dump_vec("deriv", fderiv, 12);
338+
} else if (arg1 == "new-simple") {
339+
float fderiv[12];
340+
for ( size_t nn = 0; nn<num; nn++ ) {
341+
zeroVec(fderiv,12);
342+
simple_dihedral_gradient( 0.0, 1.0, 10.0, 2.0, 2, 0, 3, 6, 9, pos, fderiv );
343+
}
344+
dump_vec("simple deriv via Enzyme", fderiv, 12 );
345+
} else if (arg1 == "old-simple") {
346+
float fderiv[12];
347+
for ( size_t nn = 0; nn<num; nn++ ) {
348+
zeroVec(fderiv,12);
349+
old_simple_dihedral_energy_gradient( 0.0, 1.0, 10.0, 2.0, 2, 0, 3, 6, 9, pos, fderiv );
350+
}
351+
dump_vec("simple deriv old code ", fderiv, 12 );
267352
}
353+
268354
}

0 commit comments

Comments
 (0)