11#include < iostream>
2+ #include < iomanip>
23#include < cstdio>
34#include < string>
45#include < cmath>
56
7+ #define PI 3.1415926535
8+
69#define TENM3 10e-3
710#define MIN (XXX,YYY ) ((XXX<YYY)?XXX:YYY)
811#define MAX (XXX,YYY ) ((XXX<YYY)?YYY:XXX)
@@ -22,6 +25,35 @@ void sinNPhiCosNPhi(int n, float* sinNPhi, float* cosNPhi,
2225 return ;
2326}
2427
28+ template <int N, typename Real>
29+ struct SinCos {
30+ static void sinNPhiCosNPhi (Real& sinNPhi, Real& cosNPhi,
31+ Real sinPhi, Real cosPhi ) {
32+ Real sinNm1Phi;
33+ Real cosNm1Phi;
34+ SinCos<N-1 ,Real>::sinNPhiCosNPhi (sinNm1Phi,cosNm1Phi,sinPhi,cosPhi);
35+ sinNPhi = cosPhi*sinNm1Phi+sinPhi*cosNm1Phi;
36+ cosNPhi = cosPhi*cosNm1Phi-sinPhi*sinNm1Phi;
37+ }
38+ };
39+
40+ template <>
41+ struct SinCos <1 ,float > {
42+ static void sinNPhiCosNPhi (float & sinNPhi, float & cosNPhi,
43+ float sinPhi, float cosPhi ) {
44+ sinNPhi = sinPhi;
45+ cosNPhi = cosPhi;
46+ }
47+ };
48+
49+ template <>
50+ struct SinCos <1 ,double > {
51+ static void sinNPhiCosNPhi (double & sinNPhi, double & cosNPhi,
52+ double sinPhi, double cosPhi ) {
53+ sinNPhi = sinPhi;
54+ cosNPhi = cosPhi;
55+ }
56+ };
2557
2658int enzyme_dup;
2759int enzyme_out;
@@ -44,9 +76,9 @@ struct Dihedral {
4476 extern double __enzyme_autodiff (void *, ...);
4577
4678void dump_vec (const std::string& msg, float * deriv, size_t num) {
47- std::cout << msg << " " ;
79+ std::cout << std::setw ( 20 ) << msg << " " ;
4880 for ( int i=0 ; i<num; i++ ) {
49- std::cout << deriv[i] << " " ;
81+ std::cout << std::setprecision ( 4 ) << std::fixed << std::setw ( 8 ) << deriv[i] << " " ;
5082 }
5183 std::cout << " \n " ;
5284}
@@ -55,19 +87,88 @@ void dump_vec(const std::string& msg, float* deriv, size_t num) {
5587#define mysqrt (VAR ) (std::sqrt(VAR))
5688#define reciprocal (VAR ) (1.0 /VAR)
5789
90+ double old_simple_dihedral_energy_gradient ( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float * pos, float * deriv ) {
91+ #define USE_EXPLICIT_DECLARES 1
92+ #define DECLARE_FLOAT (VAR ) double VAR;
93+ #define Real double
94+ double EraseLinearDihedral;
95+ double x1, y1, z1;
96+ double x2, y2, z2;
97+ double x3, y3, z3;
98+ double x4, y4, z4;
99+ #include " _Dihedral_termDeclares.cc"
100+ #define DIHEDRAL_SET_PARAMETER (VAR ) {}
101+ #define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
102+ #define DIHEDRAL_ENERGY_ACCUMULATE (VAR ) { result += VAR; }
103+ #define DIHEDRAL_FORCE_ACCUMULATE (IDX,OFFSET,VAR ) { deriv[IDX+OFFSET] -= VAR; }
104+ double result = 0.0 ;
105+ double SinNPhi;
106+ double CosNPhi;
107+ bool calcForce = true ;
108+ #define DIHEDRAL_CALC_FORCE 1
109+ #include " _Dihedral_termCode.cc"
110+ #undef Real
111+ #undef DECLARE_FLOAT
112+ #undef DIHEDRAL_SET_PARAMETER
113+ #undef DIHEDRAL_SET_POSITION
114+ #undef DIHEDRAL_ENERGY_ACCUMULATE
115+ #undef DIHEDRAL_FORCE_ACCUMULATE
116+ return result;
117+ }
118+
119+ double simple_dihedral_energy ( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float * pos ) {
120+ #define USE_EXPLICIT_DECLARES 1
121+ #define DECLARE_FLOAT (VAR ) double VAR;
122+ #define Real double
123+ double EraseLinearDihedral;
124+ double x1, y1, z1;
125+ double x2, y2, z2;
126+ double x3, y3, z3;
127+ double x4, y4, z4;
128+ #include " _DihedralEnergy_termDeclares.cc"
129+ #define DIHEDRAL_SET_PARAMETER (VAR ) {}
130+ #define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
131+ #define DIHEDRAL_ENERGY_ACCUMULATE (VAR ) { result += VAR; }
132+ double result = 0.0 ;
133+ double SinNPhi;
134+ double CosNPhi;
135+ #include " _DihedralEnergy_termCode.cc"
136+ #undef Real
137+ #undef DECLARE_FLOAT
138+ #undef DIHEDRAL_SET_PARAMETER
139+ #undef DIHEDRAL_SET_POSITION
140+ #undef DIHEDRAL_ENERGY_ACCUMULATE
141+ return result;
142+ }
143+
144+ void simple_dihedral_gradient ( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float * pos, float * grad ) {
145+ __enzyme_autodiff ( (void *)simple_dihedral_energy,
146+ enzyme_const, sinPhase,
147+ enzyme_const, cosPhase,
148+ enzyme_const, V,
149+ enzyme_const, DN,
150+ enzyme_const, IN,
151+ enzyme_const, I1,
152+ enzyme_const, I2,
153+ enzyme_const, I3,
154+ enzyme_const, I4,
155+ enzyme_dup, pos, grad );
156+ }
157+
158+ template <typename Real>
58159float dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos) {
59160#define USE_EXPLICIT_DECLARES 1
60- #define DECLARE_FLOAT (VAR ) float VAR;
61- float EraseLinearDihedral;
62- float sinPhase;
63- float cosPhase;
64- float V;
65- float DN;
66- float IN;
67- float x1, y1, z1;
68- float x2, y2, z2;
69- float x3, y3, z3;
70- float x4, y4, z4;
161+ #define DECLARE_FLOAT (VAR ) Real VAR;
162+ Real EraseLinearDihedral;
163+ Real sinPhase;
164+ Real cosPhase;
165+ Real V;
166+ Real DN;
167+ int IN;
168+ Real x1, y1, z1;
169+ Real x2, y2, z2;
170+ Real x3, y3, z3;
171+ Real x4, y4, z4;
71172#include " _DihedralEnergy_termDeclares.cc"
72173#define DIHEDRAL_SET_PARAMETER (VAR ) { VAR = dihedral->VAR ; }
73174#define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
@@ -76,9 +177,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
76177 int I2;
77178 int I3;
78179 int I4;
79- float result = 0.0 ;
80- float SinNPhi;
81- float CosNPhi;
180+ Real result = 0.0 ;
181+ Real SinNPhi;
182+ Real CosNPhi;
82183 for ( auto dihedral = dihedral_begin; dihedral < dihedral_end; dihedral++ ) {
83184#include " _DihedralEnergy_termCode.cc"
84185 }
@@ -89,8 +190,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
89190 return result;
90191}
91192
92- void grad_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv ) {
93- __enzyme_autodiff ( (void *)dihedral_energy,
193+ template <typename Real>
194+ __attribute__ ((noinline)) void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv ) {
195+ __enzyme_autodiff ( (void *)dihedral_energy<Real>,
94196 enzyme_const, dihedral_begin,
95197 enzyme_const, dihedral_end,
96198 enzyme_dup, pos, deriv );
@@ -99,7 +201,7 @@ void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, floa
99201void finite_diff_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv)
100202{
101203 // Calculate the energy using the dihedral_energy function
102- float energy_old = dihedral_energy (dihedral_begin, dihedral_end, pos);
204+ float energy_old = dihedral_energy< float > (dihedral_begin, dihedral_end, pos);
103205
104206 // Iterate over each position in the pos array
105207
@@ -115,41 +217,42 @@ void finite_diff_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_en
115217 posp[i] = pos[i] + TENM3;
116218 posm[i] = pos[i] - TENM3;
117219
118- float energy_new_p = dihedral_energy (dihedral_begin, dihedral_end, posp);
119- float energy_new_m = dihedral_energy (dihedral_begin, dihedral_end, posm);
220+ float energy_new_p = dihedral_energy< float > (dihedral_begin, dihedral_end, posp);
221+ float energy_new_m = dihedral_energy< float > (dihedral_begin, dihedral_end, posm);
120222
121223 // Calculate the derivative using finite differences
122224 deriv[i] = (energy_new_p - energy_new_m) / (2.0 *TENM3);
123225 }
124226}
125227
126- float old_dihedral_energy (Dihedral* dihedral_begin, Dihedral*dihedral_end, float * pos, float * deriv) {
228+ template <typename Real>
229+ Real old_dihedral_energy (Dihedral* dihedral_begin, Dihedral*dihedral_end, float * pos, float * deriv) {
127230#define USE_EXPLICIT_DECLARES 1
128- #define DECLARE_FLOAT (VAR ) float VAR;
129- float sinPhase;
130- float cosPhase;
131- float V;
132- float DN;
133- float IN;
134- float x1, y1, z1;
135- float x2, y2, z2;
136- float x3, y3, z3;
137- float x4, y4, z4;
231+ #define DECLARE_FLOAT (VAR ) Real VAR;
232+ Real sinPhase;
233+ Real cosPhase;
234+ Real V;
235+ Real DN;
236+ int IN;
237+ Real x1, y1, z1;
238+ Real x2, y2, z2;
239+ Real x3, y3, z3;
240+ Real x4, y4, z4;
138241#include " _Dihedral_termDeclares.cc"
139242#define DIHEDRAL_SET_PARAMETER (VAR ) { VAR = dihedral->VAR ; }
140243#define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
141244#define DIHEDRAL_ENERGY_ACCUMULATE (VAR ) { result += VAR; }
142- #define DIHEDRAL_FORCE_ACCUMULATE (IDX,OFFSET,VAR ) { deriv[IDX+OFFSET] + = VAR; }
245+ #define DIHEDRAL_FORCE_ACCUMULATE (IDX,OFFSET,VAR ) { deriv[IDX+OFFSET] - = VAR; }
143246 int I1;
144247 int I2;
145248 int I3;
146249 int I4;
147- float SinNPhi;
148- float CosNPhi;
149- float result = 0.0 ;
250+ Real SinNPhi;
251+ Real CosNPhi;
252+ Real result = 0.0 ;
150253 bool calcForce = true ;
151254#define DIHEDRAL_CALC_FORCE 1
152- float EraseLinearDihedral;
255+ Real EraseLinearDihedral;
153256 for ( auto dihedral = dihedral_begin; dihedral<dihedral_end; dihedral++ ) {
154257#include " _Dihedral_termCode.cc"
155258 }
@@ -165,56 +268,87 @@ void zeroVec(float* deriv, size_t num) {
165268 deriv[i] = 0.0 ;
166269 }
167270}
271+ void copyVec (float * result, float * val, size_t num) {
272+ for (int i=0 ; i<num; i++ ) {
273+ result[i] = val[i];
274+ }
275+ }
276+
277+
168278int main ( int argc, const char * argv[] ) {
169279 float ANG = 20.0 *0.0174533 ;
170280 float pos[12 ] = {0.0 , 0.0 , 1.0 ,
171- 0.0 , 0.0 , 0.0 ,
172- 1.0 , 0.0 , 0.0 ,
173- 1.0 , (float )-sin (ANG), (float )cos (ANG) };
281+ 0.0 , 0.0 , 0.0 ,
282+ 1.0 , 0.0 , 0.0 ,
283+ 1.0 , (float )-sin (ANG), (float )cos (ANG) };
174284 float deriv[12 ];
175285 Dihedral dihedral[] = { {0.0 , 1.0 , 10.0 , 2.0 , 2 , 0 , 3 , 6 , 9 } };
176286
177287 dump_vec (" pos" , pos, 12 );
178288 float energy = 0.0 ;
179289 std::string arg1 (argv[1 ]);
180- int donew = 0 ;
181-
182- if (arg1 == " new" )
183- donew = 1 ;
184- else if (arg1 == " fin" )
185- donew = 2 ;
186- else
187- donew = 3 ;
188290
189291 size_t num = atoi (argv[2 ]);
190- if (donew == 1 )// for new
191- {
292+ if (arg1 == " new-float" ) {
192293 std::cout << " New method" << " \n " ;
193294 // energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
194295 for ( size_t nn = 0 ; nn<num; nn++ ) {
195296 zeroVec (deriv,12 );
196- grad_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv);
297+ grad_dihedral_energy< float > ( &dihedral[0 ], &dihedral[1 ], pos, deriv);
197298 }
198- }
199-
200- else if (donew == 2 )// for finite diff
201- {
202- for ( size_t nn = 0 ; nn<num; nn++ ) {
203- zeroVec (deriv,12 );
204- finite_diff_dihedral_energy (&dihedral[0 ], &dihedral[1 ], pos, deriv);
205- }
299+ dump_vec (" new float deriv" , deriv, 12 );
300+ } else if (arg1 == " new-double" ) {
301+ std::cout << " New method" << " \n " ;
302+ // energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
303+ for ( size_t nn = 0 ; nn<num; nn++ ) {
304+ zeroVec (deriv,12 );
305+ grad_dihedral_energy<double >( &dihedral[0 ], &dihedral[1 ], pos, deriv);
306+ }
307+ dump_vec (" new double deriv" , deriv, 12 );
308+ } else if ( arg1 == " fin" ) {
309+ for ( size_t nn = 0 ; nn<num; nn++ ) {
310+ zeroVec (deriv,12 );
311+ finite_diff_dihedral_energy (&dihedral[0 ], &dihedral[1 ], pos, deriv);
206312 }
207-
208-
209- else // for old
210- {
313+ dump_vec (" fin deriv" , deriv, 12 );
314+ } else if (arg1 == " old-float" ) {
315+ float fenergy;
316+ float fderiv[12 ];
317+ float fpos[12 ];
318+ copyVec (fpos,pos,12 );
211319 std::cout << " Old method" << " \n " ;
212320 for ( size_t nn = 0 ; nn<num; nn++ ) {
213- zeroVec (deriv,12 );
214- energy = old_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv );
321+ zeroVec (fderiv,12 );
322+ energy = old_dihedral_energy<float >( &dihedral[0 ], &dihedral[1 ], fpos, fderiv );
323+ }
324+ std::cout << " Energy = " << fenergy << " \n " ;
325+ dump_vec (" deriv" , fderiv, 12 );
326+ } else if (arg1 == " old-double" ) {
327+ float fenergy;
328+ float fderiv[12 ];
329+ float fpos[12 ];
330+ copyVec (fpos,pos,12 );
331+ std::cout << " Old method" << " \n " ;
332+ for ( size_t nn = 0 ; nn<num; nn++ ) {
333+ zeroVec (fderiv,12 );
334+ energy = old_dihedral_energy<double >( &dihedral[0 ], &dihedral[1 ], fpos, fderiv );
335+ }
336+ std::cout << " Energy = " << fenergy << " \n " ;
337+ dump_vec (" deriv" , fderiv, 12 );
338+ } else if (arg1 == " new-simple" ) {
339+ float fderiv[12 ];
340+ for ( size_t nn = 0 ; nn<num; nn++ ) {
341+ zeroVec (fderiv,12 );
342+ simple_dihedral_gradient ( 0.0 , 1.0 , 10.0 , 2.0 , 2 , 0 , 3 , 6 , 9 , pos, fderiv );
343+ }
344+ dump_vec (" simple deriv via Enzyme" , fderiv, 12 );
345+ } else if (arg1 == " old-simple" ) {
346+ float fderiv[12 ];
347+ for ( size_t nn = 0 ; nn<num; nn++ ) {
348+ zeroVec (fderiv,12 );
349+ old_simple_dihedral_energy_gradient ( 0.0 , 1.0 , 10.0 , 2.0 , 2 , 0 , 3 , 6 , 9 , pos, fderiv );
215350 }
351+ dump_vec (" simple deriv old code " , fderiv, 12 );
216352 }
217- std::cout << " Energy = " << energy << " \n " ;
218- dump_vec (" deriv" , deriv, 12 );
219353
220354}
0 commit comments