7
7
#define MIN (XXX,YYY ) ((XXX<YYY)?XXX:YYY)
8
8
#define MAX (XXX,YYY ) ((XXX<YYY)?YYY:XXX)
9
9
10
- void sinNPhiCosNPhi (int n, double * sinNPhi, double * cosNPhi,
11
- double sinPhi, double cosPhi )
10
+ void sinNPhiCosNPhi (int n, float * sinNPhi, float * cosNPhi,
11
+ float sinPhi, float cosPhi )
12
12
{
13
- double sinNm1Phi, cosNm1Phi;
13
+ float sinNm1Phi, cosNm1Phi;
14
14
if ( n==1 ) {
15
15
*sinNPhi = sinPhi;
16
16
*cosNPhi = cosPhi;
@@ -32,7 +32,7 @@ struct Dihedral {
32
32
float cosPhase;
33
33
float V;
34
34
float DN;
35
- float IN;
35
+ int IN;
36
36
int I1;
37
37
int I2;
38
38
int I3;
@@ -90,12 +90,39 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
90
90
}
91
91
92
92
void grad_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv ) {
93
- __enzyme_autodiff ( (void *)stretch_energy ,
93
+ __enzyme_autodiff ( (void *)dihedral_energy ,
94
94
enzyme_const, dihedral_begin,
95
95
enzyme_const, dihedral_end,
96
96
enzyme_dup, pos, deriv );
97
97
}
98
98
99
+ void finite_diff_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv)
100
+ {
101
+ // Calculate the energy using the dihedral_energy function
102
+ float energy_old = dihedral_energy (dihedral_begin, dihedral_end, pos);
103
+
104
+ // Iterate over each position in the pos array
105
+
106
+ for (int i = 0 ; i < 12 ; i++)
107
+ {
108
+ float posp[12 ];
109
+ float posm[12 ];
110
+ for (int j = 0 ; j<12 ; j++)
111
+ {
112
+ posp[j] = pos[j];
113
+ posm[j] = pos[j];
114
+ }
115
+ posp[i] = pos[i] + TENM3;
116
+ posm[i] = pos[i] - TENM3;
117
+
118
+ float energy_new_p = dihedral_energy (dihedral_begin, dihedral_end, posp);
119
+ float energy_new_m = dihedral_energy (dihedral_begin, dihedral_end, posm);
120
+
121
+ // Calculate the derivative using finite differences
122
+ deriv[i] = (energy_new_p - energy_new_m) / (2.0 *TENM3);
123
+ }
124
+ }
125
+
99
126
float old_dihedral_energy (Dihedral* dihedral_begin, Dihedral*dihedral_end, float * pos, float * deriv) {
100
127
#define USE_EXPLICIT_DECLARES 1
101
128
#define DECLARE_FLOAT (VAR ) float VAR;
@@ -117,9 +144,12 @@ float old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float
117
144
int I2;
118
145
int I3;
119
146
int I4;
147
+ float SinNPhi;
148
+ float CosNPhi;
120
149
float result = 0.0 ;
121
150
bool calcForce = true ;
122
151
#define DIHEDRAL_CALC_FORCE 1
152
+ float EraseLinearDihedral;
123
153
for ( auto dihedral = dihedral_begin; dihedral<dihedral_end; dihedral++ ) {
124
154
#include " _Dihedral_termCode.cc"
125
155
}
@@ -136,27 +166,52 @@ void zeroVec(float* deriv, size_t num) {
136
166
}
137
167
}
138
168
int main ( int argc, const char * argv[] ) {
139
- float pos[12 ] = {0.0 , 19.0 , 3.0 , 10.0 , 7.0 , 80.0 ,
140
- 20.0 , 15.0 , 17.0 , 25.0 , 44.0 , 23.0 };
169
+ float ANG = 20.0 *0.0174533 ;
170
+ float pos[12 ] = {0.0 , 0.0 , 1.0 ,
171
+ 0.0 , 0.0 , 0.0 ,
172
+ 1.0 , 0.0 , 0.0 ,
173
+ 1.0 , (float )-sin (ANG), (float )cos (ANG) };
141
174
float deriv[12 ];
142
- Stretch stretch [] = { {10 .0 , 2 .0 , 0 , 3 }, { 20. 0 , 3.0 , 6 , 9 } };
175
+ Dihedral dihedral [] = { {0 .0 , 1 .0 , 10. 0 , 2.0 , 2 , 0 , 3 , 6 , 9 } };
143
176
144
177
dump_vec (" pos" , pos, 12 );
145
178
float energy = 0.0 ;
146
179
std::string arg1 (argv[1 ]);
147
- bool donew = (arg1 == " new" );
180
+ int donew = 0 ;
181
+
182
+ if (arg1 == " new" )
183
+ donew = 1 ;
184
+ else if (arg1 == " fin" )
185
+ donew = 2 ;
186
+ else
187
+ donew = 3 ;
188
+
148
189
size_t num = atoi (argv[2 ]);
149
- if (donew) {
190
+ if (donew == 1 )// for new
191
+ {
150
192
std::cout << " New method" << " \n " ;
193
+ // energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
151
194
for ( size_t nn = 0 ; nn<num; nn++ ) {
152
195
zeroVec (deriv,12 );
153
- grad_dihedral_energy ( &dihedral[0 ], &dihedral[2 ], pos, deriv);
196
+ grad_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv);
154
197
}
155
- } else {
198
+ }
199
+
200
+ else if (donew == 2 )// for finite diff
201
+ {
202
+ for ( size_t nn = 0 ; nn<num; nn++ ) {
203
+ zeroVec (deriv,12 );
204
+ finite_diff_dihedral_energy (&dihedral[0 ], &dihedral[1 ], pos, deriv);
205
+ }
206
+ }
207
+
208
+
209
+ else // for old
210
+ {
156
211
std::cout << " Old method" << " \n " ;
157
212
for ( size_t nn = 0 ; nn<num; nn++ ) {
158
213
zeroVec (deriv,12 );
159
- energy = old_dihedral_energy ( &dihedral[0 ], &dihedral[2 ], pos, deriv );
214
+ energy = old_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv );
160
215
}
161
216
}
162
217
std::cout << " Energy = " << energy << " \n " ;
0 commit comments