1
1
#include < iostream>
2
+ #include < iomanip>
2
3
#include < cstdio>
3
4
#include < string>
4
5
#include < cmath>
5
6
7
+ #define PI 3.1415926535
8
+
6
9
#define TENM3 10e-3
7
10
#define MIN (XXX,YYY ) ((XXX<YYY)?XXX:YYY)
8
11
#define MAX (XXX,YYY ) ((XXX<YYY)?YYY:XXX)
@@ -22,27 +25,36 @@ void sinNPhiCosNPhi(int n, float* sinNPhi, float* cosNPhi,
22
25
return ;
23
26
}
24
27
25
- template <int N>
28
+ template <int N, typename Real >
26
29
struct SinCos {
27
- static void sinNPhiCosNPhi (float & sinNPhi, float & cosNPhi,
28
- float sinPhi, float cosPhi ) {
29
- float sinNm1Phi;
30
- float cosNm1Phi;
31
- SinCos<N-1 >::sinNPhiCosNPhi (sinNm1Phi,cosNm1Phi,sinPhi,cosPhi);
30
+ static void sinNPhiCosNPhi (Real & sinNPhi, Real & cosNPhi,
31
+ Real sinPhi, Real cosPhi ) {
32
+ Real sinNm1Phi;
33
+ Real cosNm1Phi;
34
+ SinCos<N-1 ,Real >::sinNPhiCosNPhi (sinNm1Phi,cosNm1Phi,sinPhi,cosPhi);
32
35
sinNPhi = cosPhi*sinNm1Phi+sinPhi*cosNm1Phi;
33
36
cosNPhi = cosPhi*cosNm1Phi-sinPhi*sinNm1Phi;
34
37
}
35
38
};
36
39
37
40
template <>
38
- struct SinCos <1 > {
41
+ struct SinCos <1 , float > {
39
42
static void sinNPhiCosNPhi (float & sinNPhi, float & cosNPhi,
40
43
float sinPhi, float cosPhi ) {
41
44
sinNPhi = sinPhi;
42
45
cosNPhi = cosPhi;
43
46
}
44
47
};
45
48
49
+ template <>
50
+ struct SinCos <1 ,double > {
51
+ static void sinNPhiCosNPhi (double & sinNPhi, double & cosNPhi,
52
+ double sinPhi, double cosPhi ) {
53
+ sinNPhi = sinPhi;
54
+ cosNPhi = cosPhi;
55
+ }
56
+ };
57
+
46
58
int enzyme_dup;
47
59
int enzyme_out;
48
60
int enzyme_const;
@@ -64,9 +76,9 @@ struct Dihedral {
64
76
extern double __enzyme_autodiff (void *, ...);
65
77
66
78
void dump_vec (const std::string& msg, float * deriv, size_t num) {
67
- std::cout << msg << " " ;
79
+ std::cout << std::setw ( 10 ) << msg << " " ;
68
80
for ( int i=0 ; i<num; i++ ) {
69
- std::cout << deriv[i] << " " ;
81
+ std::cout << std::setprecision ( 4 ) << std::fixed << std::setw ( 8 ) << deriv[i] << " " ;
70
82
}
71
83
std::cout << " \n " ;
72
84
}
@@ -75,19 +87,20 @@ void dump_vec(const std::string& msg, float* deriv, size_t num) {
75
87
#define mysqrt (VAR ) (std::sqrt(VAR))
76
88
#define reciprocal (VAR ) (1.0 /VAR)
77
89
90
+ template <typename Real>
78
91
float dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos) {
79
92
#define USE_EXPLICIT_DECLARES 1
80
- #define DECLARE_FLOAT (VAR ) float VAR;
81
- float EraseLinearDihedral;
82
- float sinPhase;
83
- float cosPhase;
84
- float V;
85
- float DN;
93
+ #define DECLARE_FLOAT (VAR ) Real VAR;
94
+ Real EraseLinearDihedral;
95
+ Real sinPhase;
96
+ Real cosPhase;
97
+ Real V;
98
+ Real DN;
86
99
int IN;
87
- float x1, y1, z1;
88
- float x2, y2, z2;
89
- float x3, y3, z3;
90
- float x4, y4, z4;
100
+ Real x1, y1, z1;
101
+ Real x2, y2, z2;
102
+ Real x3, y3, z3;
103
+ Real x4, y4, z4;
91
104
#include " _DihedralEnergy_termDeclares.cc"
92
105
#define DIHEDRAL_SET_PARAMETER (VAR ) { VAR = dihedral->VAR ; }
93
106
#define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
@@ -96,9 +109,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
96
109
int I2;
97
110
int I3;
98
111
int I4;
99
- float result = 0.0 ;
100
- float SinNPhi;
101
- float CosNPhi;
112
+ Real result = 0.0 ;
113
+ Real SinNPhi;
114
+ Real CosNPhi;
102
115
for ( auto dihedral = dihedral_begin; dihedral < dihedral_end; dihedral++ ) {
103
116
#include " _DihedralEnergy_termCode.cc"
104
117
}
@@ -109,8 +122,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
109
122
return result;
110
123
}
111
124
125
+ template <typename Real>
112
126
void grad_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv ) {
113
- __enzyme_autodiff ( (void *)dihedral_energy,
127
+ __enzyme_autodiff ( (void *)dihedral_energy<Real> ,
114
128
enzyme_const, dihedral_begin,
115
129
enzyme_const, dihedral_end,
116
130
enzyme_dup, pos, deriv );
@@ -119,7 +133,7 @@ void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, floa
119
133
void finite_diff_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv)
120
134
{
121
135
// Calculate the energy using the dihedral_energy function
122
- float energy_old = dihedral_energy (dihedral_begin, dihedral_end, pos);
136
+ float energy_old = dihedral_energy< float > (dihedral_begin, dihedral_end, pos);
123
137
124
138
// Iterate over each position in the pos array
125
139
@@ -135,26 +149,27 @@ void finite_diff_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_en
135
149
posp[i] = pos[i] + TENM3;
136
150
posm[i] = pos[i] - TENM3;
137
151
138
- float energy_new_p = dihedral_energy (dihedral_begin, dihedral_end, posp);
139
- float energy_new_m = dihedral_energy (dihedral_begin, dihedral_end, posm);
152
+ float energy_new_p = dihedral_energy< float > (dihedral_begin, dihedral_end, posp);
153
+ float energy_new_m = dihedral_energy< float > (dihedral_begin, dihedral_end, posm);
140
154
141
155
// Calculate the derivative using finite differences
142
156
deriv[i] = (energy_new_p - energy_new_m) / (2.0 *TENM3);
143
157
}
144
158
}
145
159
146
- float old_dihedral_energy (Dihedral* dihedral_begin, Dihedral*dihedral_end, float * pos, float * deriv) {
160
+ template <typename Real>
161
+ Real old_dihedral_energy (Dihedral* dihedral_begin, Dihedral*dihedral_end, float * pos, float * deriv) {
147
162
#define USE_EXPLICIT_DECLARES 1
148
- #define DECLARE_FLOAT (VAR ) float VAR;
149
- float sinPhase;
150
- float cosPhase;
151
- float V;
152
- float DN;
163
+ #define DECLARE_FLOAT (VAR ) Real VAR;
164
+ Real sinPhase;
165
+ Real cosPhase;
166
+ Real V;
167
+ Real DN;
153
168
int IN;
154
- float x1, y1, z1;
155
- float x2, y2, z2;
156
- float x3, y3, z3;
157
- float x4, y4, z4;
169
+ Real x1, y1, z1;
170
+ Real x2, y2, z2;
171
+ Real x3, y3, z3;
172
+ Real x4, y4, z4;
158
173
#include " _Dihedral_termDeclares.cc"
159
174
#define DIHEDRAL_SET_PARAMETER (VAR ) { VAR = dihedral->VAR ; }
160
175
#define DIHEDRAL_SET_POSITION (VAR,IDX,OFFSET ) { VAR = pos[IDX+OFFSET]; }
@@ -164,12 +179,12 @@ float old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float
164
179
int I2;
165
180
int I3;
166
181
int I4;
167
- float SinNPhi;
168
- float CosNPhi;
169
- float result = 0.0 ;
182
+ Real SinNPhi;
183
+ Real CosNPhi;
184
+ Real result = 0.0 ;
170
185
bool calcForce = true ;
171
186
#define DIHEDRAL_CALC_FORCE 1
172
- float EraseLinearDihedral;
187
+ Real EraseLinearDihedral;
173
188
for ( auto dihedral = dihedral_begin; dihedral<dihedral_end; dihedral++ ) {
174
189
#include " _Dihedral_termCode.cc"
175
190
}
@@ -185,56 +200,69 @@ void zeroVec(float* deriv, size_t num) {
185
200
deriv[i] = 0.0 ;
186
201
}
187
202
}
203
+ void copyVec (float * result, float * val, size_t num) {
204
+ for (int i=0 ; i<num; i++ ) {
205
+ result[i] = val[i];
206
+ }
207
+ }
208
+
209
+
188
210
int main ( int argc, const char * argv[] ) {
189
211
float ANG = 20.0 *0.0174533 ;
190
212
float pos[12 ] = {0.0 , 0.0 , 1.0 ,
191
- 0.0 , 0.0 , 0.0 ,
192
- 1.0 , 0.0 , 0.0 ,
193
- 1.0 , (float )-sin (ANG), (float )cos (ANG) };
213
+ 0.0 , 0.0 , 0.0 ,
214
+ 1.0 , 0.0 , 0.0 ,
215
+ 1.0 , (float )-sin (ANG), (float )cos (ANG) };
194
216
float deriv[12 ];
195
217
Dihedral dihedral[] = { {0.0 , 1.0 , 10.0 , 2.0 , 2 , 0 , 3 , 6 , 9 } };
196
218
197
219
dump_vec (" pos" , pos, 12 );
198
220
float energy = 0.0 ;
199
221
std::string arg1 (argv[1 ]);
200
- int donew = 0 ;
201
-
202
- if (arg1 == " new" )
203
- donew = 1 ;
204
- else if (arg1 == " fin" )
205
- donew = 2 ;
206
- else
207
- donew = 3 ;
208
222
209
223
size_t num = atoi (argv[2 ]);
210
- if (donew == 1 )// for new
211
- {
224
+ if (arg1 == " new-float" ) {
212
225
std::cout << " New method" << " \n " ;
213
226
// energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
214
227
for ( size_t nn = 0 ; nn<num; nn++ ) {
215
228
zeroVec (deriv,12 );
216
- grad_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv);
229
+ grad_dihedral_energy< float > ( &dihedral[0 ], &dihedral[1 ], pos, deriv);
217
230
}
218
- }
219
-
220
- else if (donew == 2 )// for finite diff
221
- {
222
- for ( size_t nn = 0 ; nn<num; nn++ ) {
223
- zeroVec (deriv,12 );
224
- finite_diff_dihedral_energy (&dihedral[0 ], &dihedral[1 ], pos, deriv);
225
- }
231
+ } else if (arg1 == " new-double" ) {
232
+ std::cout << " New method" << " \n " ;
233
+ // energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
234
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
235
+ zeroVec (deriv,12 );
236
+ grad_dihedral_energy<double >( &dihedral[0 ], &dihedral[1 ], pos, deriv);
237
+ }
238
+ } else if ( arg1 == " fin" ) {
239
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
240
+ zeroVec (deriv,12 );
241
+ finite_diff_dihedral_energy (&dihedral[0 ], &dihedral[1 ], pos, deriv);
226
242
}
227
-
228
-
229
- else // for old
230
- {
243
+ } else if (arg1 == " old-float" ) {
244
+ float fenergy;
245
+ float fderiv[12 ];
246
+ float fpos[12 ];
247
+ copyVec (fpos,pos,12 );
231
248
std::cout << " Old method" << " \n " ;
232
249
for ( size_t nn = 0 ; nn<num; nn++ ) {
233
- zeroVec (deriv ,12 );
234
- energy = old_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv );
250
+ zeroVec (fderiv ,12 );
251
+ energy = old_dihedral_energy< float > ( &dihedral[0 ], &dihedral[1 ], fpos, fderiv );
235
252
}
253
+ std::cout << " Energy = " << fenergy << " \n " ;
254
+ dump_vec (" deriv" , fderiv, 12 );
255
+ } else if (arg1 == " old-double" ) {
256
+ float fenergy;
257
+ float fderiv[12 ];
258
+ float fpos[12 ];
259
+ copyVec (fpos,pos,12 );
260
+ std::cout << " Old method" << " \n " ;
261
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
262
+ zeroVec (fderiv,12 );
263
+ energy = old_dihedral_energy<double >( &dihedral[0 ], &dihedral[1 ], fpos, fderiv );
264
+ }
265
+ std::cout << " Energy = " << fenergy << " \n " ;
266
+ dump_vec (" deriv" , fderiv, 12 );
236
267
}
237
- std::cout << " Energy = " << energy << " \n " ;
238
- dump_vec (" deriv" , deriv, 12 );
239
-
240
268
}
0 commit comments