1
1
#include < iostream>
2
+ #include < iomanip>
2
3
#include < cstdio>
3
4
#include < string>
4
5
#include < cmath>
5
6
7
+ #define PI 3.1415926535
8
+
6
9
#define TENM3 10e-3
7
10
#define MIN (XXX,YYY ) ((XXX<YYY)?XXX:YYY)
8
11
#define MAX (XXX,YYY ) ((XXX<YYY)?YYY:XXX)
@@ -22,6 +25,35 @@ void sinNPhiCosNPhi(int n, float* sinNPhi, float* cosNPhi,
22
25
return ;
23
26
}
24
27
28
+ template <int N, typename Real>
29
+ struct SinCos {
30
+ static void sinNPhiCosNPhi (Real& sinNPhi, Real& cosNPhi,
31
+ Real sinPhi, Real cosPhi ) {
32
+ Real sinNm1Phi;
33
+ Real cosNm1Phi;
34
+ SinCos<N-1 ,Real>::sinNPhiCosNPhi (sinNm1Phi,cosNm1Phi,sinPhi,cosPhi);
35
+ sinNPhi = cosPhi*sinNm1Phi+sinPhi*cosNm1Phi;
36
+ cosNPhi = cosPhi*cosNm1Phi-sinPhi*sinNm1Phi;
37
+ }
38
+ };
39
+
40
+ template <>
41
+ struct SinCos <1 ,float > {
42
+ static void sinNPhiCosNPhi (float & sinNPhi, float & cosNPhi,
43
+ float sinPhi, float cosPhi ) {
44
+ sinNPhi = sinPhi;
45
+ cosNPhi = cosPhi;
46
+ }
47
+ };
48
+
49
+ template <>
50
+ struct SinCos <1 ,double > {
51
+ static void sinNPhiCosNPhi (double & sinNPhi, double & cosNPhi,
52
+ double sinPhi, double cosPhi ) {
53
+ sinNPhi = sinPhi;
54
+ cosNPhi = cosPhi;
55
+ }
56
+ };
25
57
26
58
int enzyme_dup;
27
59
int enzyme_out;
@@ -44,9 +76,9 @@ struct Dihedral {
44
76
extern double __enzyme_autodiff (void *, ...);
45
77
46
78
void dump_vec (const std::string& msg, float * deriv, size_t num) {
47
- std::cout << msg << " " ;
79
+ std::cout << std::setw ( 20 ) << msg << " " ;
48
80
for ( int i=0 ; i<num; i++ ) {
49
- std::cout << deriv[i] << " " ;
81
+ std::cout << std::setprecision ( 4 ) << std::fixed << std::setw ( 8 ) << deriv[i] << " " ;
50
82
}
51
83
std::cout << " \n " ;
52
84
}
@@ -55,19 +87,88 @@ void dump_vec(const std::string& msg, float* deriv, size_t num) {
55
87
#define mysqrt (VAR ) (std::sqrt(VAR))
56
88
#define reciprocal (VAR ) (1.0 /VAR)
57
89
90
+ double old_simple_dihedral_energy_gradient ( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float * pos, float * deriv ) {
91
+ #define USE_EXPLICIT_DECLARES 1
92
+ #define DECLARE_FLOAT (VAR ) double VAR;
93
+ #define Real double
94
+ double EraseLinearDihedral;
95
+ double x1, y1 , z1;
96
+ double x2, y2, z2;
97
+ double x3, y3, z3;
98
+ double x4, y4, z4;
99
+ #include " _Dihedral_termDeclares.cc"
100
+ #define DIHEDRAL_SET_PARAMETER (VAR ) {}
101
+ #define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
102
+ #define DIHEDRAL_ENERGY_ACCUMULATE (VAR ) { result += VAR; }
103
+ #define DIHEDRAL_FORCE_ACCUMULATE (IDX,OFFSET,VAR ) { deriv[IDX+OFFSET] -= VAR; }
104
+ double result = 0.0 ;
105
+ double SinNPhi;
106
+ double CosNPhi;
107
+ bool calcForce = true ;
108
+ #define DIHEDRAL_CALC_FORCE 1
109
+ #include " _Dihedral_termCode.cc"
110
+ #undef Real
111
+ #undef DECLARE_FLOAT
112
+ #undef DIHEDRAL_SET_PARAMETER
113
+ #undef DIHEDRAL_SET_POSITION
114
+ #undef DIHEDRAL_ENERGY_ACCUMULATE
115
+ #undef DIHEDRAL_FORCE_ACCUMULATE
116
+ return result;
117
+ }
118
+
119
+ double simple_dihedral_energy ( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float * pos ) {
120
+ #define USE_EXPLICIT_DECLARES 1
121
+ #define DECLARE_FLOAT (VAR ) double VAR;
122
+ #define Real double
123
+ double EraseLinearDihedral;
124
+ double x1, y1 , z1;
125
+ double x2, y2, z2;
126
+ double x3, y3, z3;
127
+ double x4, y4, z4;
128
+ #include " _DihedralEnergy_termDeclares.cc"
129
+ #define DIHEDRAL_SET_PARAMETER (VAR ) {}
130
+ #define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
131
+ #define DIHEDRAL_ENERGY_ACCUMULATE (VAR ) { result += VAR; }
132
+ double result = 0.0 ;
133
+ double SinNPhi;
134
+ double CosNPhi;
135
+ #include " _DihedralEnergy_termCode.cc"
136
+ #undef Real
137
+ #undef DECLARE_FLOAT
138
+ #undef DIHEDRAL_SET_PARAMETER
139
+ #undef DIHEDRAL_SET_POSITION
140
+ #undef DIHEDRAL_ENERGY_ACCUMULATE
141
+ return result;
142
+ }
143
+
144
+ void simple_dihedral_gradient ( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float * pos, float * grad ) {
145
+ __enzyme_autodiff ( (void *)simple_dihedral_energy,
146
+ enzyme_const, sinPhase,
147
+ enzyme_const, cosPhase,
148
+ enzyme_const, V,
149
+ enzyme_const, DN,
150
+ enzyme_const, IN,
151
+ enzyme_const, I1,
152
+ enzyme_const, I2,
153
+ enzyme_const, I3,
154
+ enzyme_const, I4,
155
+ enzyme_dup, pos, grad );
156
+ }
157
+
158
+ template <typename Real>
58
159
float dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos) {
59
160
#define USE_EXPLICIT_DECLARES 1
60
- #define DECLARE_FLOAT (VAR ) float VAR;
61
- float EraseLinearDihedral;
62
- float sinPhase;
63
- float cosPhase;
64
- float V;
65
- float DN;
66
- float IN;
67
- float x1, y1 , z1;
68
- float x2, y2, z2;
69
- float x3, y3, z3;
70
- float x4, y4, z4;
161
+ #define DECLARE_FLOAT (VAR ) Real VAR;
162
+ Real EraseLinearDihedral;
163
+ Real sinPhase;
164
+ Real cosPhase;
165
+ Real V;
166
+ Real DN;
167
+ int IN;
168
+ Real x1, y1 , z1;
169
+ Real x2, y2, z2;
170
+ Real x3, y3, z3;
171
+ Real x4, y4, z4;
71
172
#include " _DihedralEnergy_termDeclares.cc"
72
173
#define DIHEDRAL_SET_PARAMETER (VAR ) { VAR = dihedral->VAR ; }
73
174
#define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
@@ -76,9 +177,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
76
177
int I2;
77
178
int I3;
78
179
int I4;
79
- float result = 0.0 ;
80
- float SinNPhi;
81
- float CosNPhi;
180
+ Real result = 0.0 ;
181
+ Real SinNPhi;
182
+ Real CosNPhi;
82
183
for ( auto dihedral = dihedral_begin; dihedral < dihedral_end; dihedral++ ) {
83
184
#include " _DihedralEnergy_termCode.cc"
84
185
}
@@ -89,8 +190,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
89
190
return result;
90
191
}
91
192
92
- void grad_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv ) {
93
- __enzyme_autodiff ( (void *)dihedral_energy,
193
+ template <typename Real>
194
+ __attribute__ ((noinline)) void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv ) {
195
+ __enzyme_autodiff ( (void *)dihedral_energy<Real>,
94
196
enzyme_const, dihedral_begin,
95
197
enzyme_const, dihedral_end,
96
198
enzyme_dup, pos, deriv );
@@ -99,7 +201,7 @@ void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, floa
99
201
void finite_diff_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv)
100
202
{
101
203
// Calculate the energy using the dihedral_energy function
102
- float energy_old = dihedral_energy (dihedral_begin, dihedral_end, pos);
204
+ float energy_old = dihedral_energy< float > (dihedral_begin, dihedral_end, pos);
103
205
104
206
// Iterate over each position in the pos array
105
207
@@ -115,41 +217,42 @@ void finite_diff_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_en
115
217
posp[i] = pos[i] + TENM3;
116
218
posm[i] = pos[i] - TENM3;
117
219
118
- float energy_new_p = dihedral_energy (dihedral_begin, dihedral_end, posp);
119
- float energy_new_m = dihedral_energy (dihedral_begin, dihedral_end, posm);
220
+ float energy_new_p = dihedral_energy< float > (dihedral_begin, dihedral_end, posp);
221
+ float energy_new_m = dihedral_energy< float > (dihedral_begin, dihedral_end, posm);
120
222
121
223
// Calculate the derivative using finite differences
122
224
deriv[i] = (energy_new_p - energy_new_m) / (2.0 *TENM3);
123
225
}
124
226
}
125
227
126
- float old_dihedral_energy (Dihedral* dihedral_begin, Dihedral*dihedral_end, float * pos, float * deriv) {
228
+ template <typename Real>
229
+ Real old_dihedral_energy (Dihedral* dihedral_begin, Dihedral*dihedral_end, float * pos, float * deriv) {
127
230
#define USE_EXPLICIT_DECLARES 1
128
- #define DECLARE_FLOAT (VAR ) float VAR;
129
- float sinPhase;
130
- float cosPhase;
131
- float V;
132
- float DN;
133
- float IN;
134
- float x1, y1 , z1;
135
- float x2, y2, z2;
136
- float x3, y3, z3;
137
- float x4, y4, z4;
231
+ #define DECLARE_FLOAT (VAR ) Real VAR;
232
+ Real sinPhase;
233
+ Real cosPhase;
234
+ Real V;
235
+ Real DN;
236
+ int IN;
237
+ Real x1, y1 , z1;
238
+ Real x2, y2, z2;
239
+ Real x3, y3, z3;
240
+ Real x4, y4, z4;
138
241
#include " _Dihedral_termDeclares.cc"
139
242
#define DIHEDRAL_SET_PARAMETER (VAR ) { VAR = dihedral->VAR ; }
140
243
#define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
141
244
#define DIHEDRAL_ENERGY_ACCUMULATE (VAR ) { result += VAR; }
142
- #define DIHEDRAL_FORCE_ACCUMULATE (IDX,OFFSET,VAR ) { deriv[IDX+OFFSET] + = VAR; }
245
+ #define DIHEDRAL_FORCE_ACCUMULATE (IDX,OFFSET,VAR ) { deriv[IDX+OFFSET] - = VAR; }
143
246
int I1;
144
247
int I2;
145
248
int I3;
146
249
int I4;
147
- float SinNPhi;
148
- float CosNPhi;
149
- float result = 0.0 ;
250
+ Real SinNPhi;
251
+ Real CosNPhi;
252
+ Real result = 0.0 ;
150
253
bool calcForce = true ;
151
254
#define DIHEDRAL_CALC_FORCE 1
152
- float EraseLinearDihedral;
255
+ Real EraseLinearDihedral;
153
256
for ( auto dihedral = dihedral_begin; dihedral<dihedral_end; dihedral++ ) {
154
257
#include " _Dihedral_termCode.cc"
155
258
}
@@ -165,56 +268,87 @@ void zeroVec(float* deriv, size_t num) {
165
268
deriv[i] = 0.0 ;
166
269
}
167
270
}
271
+ void copyVec (float * result, float * val, size_t num) {
272
+ for (int i=0 ; i<num; i++ ) {
273
+ result[i] = val[i];
274
+ }
275
+ }
276
+
277
+
168
278
int main ( int argc, const char * argv[] ) {
169
279
float ANG = 20.0 *0.0174533 ;
170
280
float pos[12 ] = {0.0 , 0.0 , 1.0 ,
171
- 0.0 , 0.0 , 0.0 ,
172
- 1.0 , 0.0 , 0.0 ,
173
- 1.0 , (float )-sin (ANG), (float )cos (ANG) };
281
+ 0.0 , 0.0 , 0.0 ,
282
+ 1.0 , 0.0 , 0.0 ,
283
+ 1.0 , (float )-sin (ANG), (float )cos (ANG) };
174
284
float deriv[12 ];
175
285
Dihedral dihedral[] = { {0.0 , 1.0 , 10.0 , 2.0 , 2 , 0 , 3 , 6 , 9 } };
176
286
177
287
dump_vec (" pos" , pos, 12 );
178
288
float energy = 0.0 ;
179
289
std::string arg1 (argv[1 ]);
180
- int donew = 0 ;
181
-
182
- if (arg1 == " new" )
183
- donew = 1 ;
184
- else if (arg1 == " fin" )
185
- donew = 2 ;
186
- else
187
- donew = 3 ;
188
290
189
291
size_t num = atoi (argv[2 ]);
190
- if (donew == 1 )// for new
191
- {
292
+ if (arg1 == " new-float" ) {
192
293
std::cout << " New method" << " \n " ;
193
294
// energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
194
295
for ( size_t nn = 0 ; nn<num; nn++ ) {
195
296
zeroVec (deriv,12 );
196
- grad_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv);
297
+ grad_dihedral_energy< float > ( &dihedral[0 ], &dihedral[1 ], pos, deriv);
197
298
}
198
- }
199
-
200
- else if (donew == 2 )// for finite diff
201
- {
202
- for ( size_t nn = 0 ; nn<num; nn++ ) {
203
- zeroVec (deriv,12 );
204
- finite_diff_dihedral_energy (&dihedral[0 ], &dihedral[1 ], pos, deriv);
205
- }
299
+ dump_vec (" new float deriv" , deriv, 12 );
300
+ } else if (arg1 == " new-double" ) {
301
+ std::cout << " New method" << " \n " ;
302
+ // energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
303
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
304
+ zeroVec (deriv,12 );
305
+ grad_dihedral_energy<double >( &dihedral[0 ], &dihedral[1 ], pos, deriv);
306
+ }
307
+ dump_vec (" new double deriv" , deriv, 12 );
308
+ } else if ( arg1 == " fin" ) {
309
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
310
+ zeroVec (deriv,12 );
311
+ finite_diff_dihedral_energy (&dihedral[0 ], &dihedral[1 ], pos, deriv);
206
312
}
207
-
208
-
209
- else // for old
210
- {
313
+ dump_vec (" fin deriv" , deriv, 12 );
314
+ } else if (arg1 == " old-float" ) {
315
+ float fenergy;
316
+ float fderiv[12 ];
317
+ float fpos[12 ];
318
+ copyVec (fpos,pos,12 );
211
319
std::cout << " Old method" << " \n " ;
212
320
for ( size_t nn = 0 ; nn<num; nn++ ) {
213
- zeroVec (deriv,12 );
214
- energy = old_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv );
321
+ zeroVec (fderiv,12 );
322
+ energy = old_dihedral_energy<float >( &dihedral[0 ], &dihedral[1 ], fpos, fderiv );
323
+ }
324
+ std::cout << " Energy = " << fenergy << " \n " ;
325
+ dump_vec (" deriv" , fderiv, 12 );
326
+ } else if (arg1 == " old-double" ) {
327
+ float fenergy;
328
+ float fderiv[12 ];
329
+ float fpos[12 ];
330
+ copyVec (fpos,pos,12 );
331
+ std::cout << " Old method" << " \n " ;
332
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
333
+ zeroVec (fderiv,12 );
334
+ energy = old_dihedral_energy<double >( &dihedral[0 ], &dihedral[1 ], fpos, fderiv );
335
+ }
336
+ std::cout << " Energy = " << fenergy << " \n " ;
337
+ dump_vec (" deriv" , fderiv, 12 );
338
+ } else if (arg1 == " new-simple" ) {
339
+ float fderiv[12 ];
340
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
341
+ zeroVec (fderiv,12 );
342
+ simple_dihedral_gradient ( 0.0 , 1.0 , 10.0 , 2.0 , 2 , 0 , 3 , 6 , 9 , pos, fderiv );
343
+ }
344
+ dump_vec (" simple deriv via Enzyme" , fderiv, 12 );
345
+ } else if (arg1 == " old-simple" ) {
346
+ float fderiv[12 ];
347
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
348
+ zeroVec (fderiv,12 );
349
+ old_simple_dihedral_energy_gradient ( 0.0 , 1.0 , 10.0 , 2.0 , 2 , 0 , 3 , 6 , 9 , pos, fderiv );
215
350
}
351
+ dump_vec (" simple deriv old code " , fderiv, 12 );
216
352
}
217
- std::cout << " Energy = " << energy << " \n " ;
218
- dump_vec (" deriv" , deriv, 12 );
219
353
220
354
}
0 commit comments