77#define MIN (XXX,YYY ) ((XXX<YYY)?XXX:YYY)
88#define MAX (XXX,YYY ) ((XXX<YYY)?YYY:XXX)
99
10- void sinNPhiCosNPhi (int n, double * sinNPhi, double * cosNPhi,
11- double sinPhi, double cosPhi )
10+ void sinNPhiCosNPhi (int n, float * sinNPhi, float * cosNPhi,
11+ float sinPhi, float cosPhi )
1212{
13- double sinNm1Phi, cosNm1Phi;
13+ float sinNm1Phi, cosNm1Phi;
1414 if ( n==1 ) {
1515 *sinNPhi = sinPhi;
1616 *cosNPhi = cosPhi;
@@ -32,7 +32,7 @@ struct Dihedral {
3232 float cosPhase;
3333 float V;
3434 float DN;
35- float IN;
35+ int IN;
3636 int I1;
3737 int I2;
3838 int I3;
@@ -90,12 +90,39 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
9090}
9191
9292void grad_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv ) {
93- __enzyme_autodiff ( (void *)stretch_energy ,
93+ __enzyme_autodiff ( (void *)dihedral_energy ,
9494 enzyme_const, dihedral_begin,
9595 enzyme_const, dihedral_end,
9696 enzyme_dup, pos, deriv );
9797}
9898
99+ void finite_diff_dihedral_energy (Dihedral* dihedral_begin, Dihedral* dihedral_end, float * pos, float * deriv)
100+ {
101+ // Calculate the energy using the dihedral_energy function
102+ float energy_old = dihedral_energy (dihedral_begin, dihedral_end, pos);
103+
104+ // Iterate over each position in the pos array
105+
106+ for (int i = 0 ; i < 12 ; i++)
107+ {
108+ float posp[12 ];
109+ float posm[12 ];
110+ for (int j = 0 ; j<12 ; j++)
111+ {
112+ posp[j] = pos[j];
113+ posm[j] = pos[j];
114+ }
115+ posp[i] = pos[i] + TENM3;
116+ posm[i] = pos[i] - TENM3;
117+
118+ float energy_new_p = dihedral_energy (dihedral_begin, dihedral_end, posp);
119+ float energy_new_m = dihedral_energy (dihedral_begin, dihedral_end, posm);
120+
121+ // Calculate the derivative using finite differences
122+ deriv[i] = (energy_new_p - energy_new_m) / (2.0 *TENM3);
123+ }
124+ }
125+
99126float old_dihedral_energy (Dihedral* dihedral_begin, Dihedral*dihedral_end, float * pos, float * deriv) {
100127#define USE_EXPLICIT_DECLARES 1
101128#define DECLARE_FLOAT (VAR ) float VAR;
@@ -117,9 +144,12 @@ float old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float
117144 int I2;
118145 int I3;
119146 int I4;
147+ float SinNPhi;
148+ float CosNPhi;
120149 float result = 0.0 ;
121150 bool calcForce = true ;
122151#define DIHEDRAL_CALC_FORCE 1
152+ float EraseLinearDihedral;
123153 for ( auto dihedral = dihedral_begin; dihedral<dihedral_end; dihedral++ ) {
124154#include " _Dihedral_termCode.cc"
125155 }
@@ -136,27 +166,52 @@ void zeroVec(float* deriv, size_t num) {
136166 }
137167}
138168int main ( int argc, const char * argv[] ) {
139- float pos[12 ] = {0.0 , 19.0 , 3.0 , 10.0 , 7.0 , 80.0 ,
140- 20.0 , 15.0 , 17.0 , 25.0 , 44.0 , 23.0 };
169+ float ANG = 20.0 *0.0174533 ;
170+ float pos[12 ] = {0.0 , 0.0 , 1.0 ,
171+ 0.0 , 0.0 , 0.0 ,
172+ 1.0 , 0.0 , 0.0 ,
173+ 1.0 , (float )-sin (ANG), (float )cos (ANG) };
141174 float deriv[12 ];
142- Stretch stretch [] = { {10 .0 , 2 .0 , 0 , 3 }, { 20. 0 , 3.0 , 6 , 9 } };
175+ Dihedral dihedral [] = { {0 .0 , 1 .0 , 10. 0 , 2.0 , 2 , 0 , 3 , 6 , 9 } };
143176
144177 dump_vec (" pos" , pos, 12 );
145178 float energy = 0.0 ;
146179 std::string arg1 (argv[1 ]);
147- bool donew = (arg1 == " new" );
180+ int donew = 0 ;
181+
182+ if (arg1 == " new" )
183+ donew = 1 ;
184+ else if (arg1 == " fin" )
185+ donew = 2 ;
186+ else
187+ donew = 3 ;
188+
148189 size_t num = atoi (argv[2 ]);
149- if (donew) {
190+ if (donew == 1 )// for new
191+ {
150192 std::cout << " New method" << " \n " ;
193+ // energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
151194 for ( size_t nn = 0 ; nn<num; nn++ ) {
152195 zeroVec (deriv,12 );
153- grad_dihedral_energy ( &dihedral[0 ], &dihedral[2 ], pos, deriv);
196+ grad_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv);
154197 }
155- } else {
198+ }
199+
200+ else if (donew == 2 )// for finite diff
201+ {
202+ for ( size_t nn = 0 ; nn<num; nn++ ) {
203+ zeroVec (deriv,12 );
204+ finite_diff_dihedral_energy (&dihedral[0 ], &dihedral[1 ], pos, deriv);
205+ }
206+ }
207+
208+
209+ else // for old
210+ {
156211 std::cout << " Old method" << " \n " ;
157212 for ( size_t nn = 0 ; nn<num; nn++ ) {
158213 zeroVec (deriv,12 );
159- energy = old_dihedral_energy ( &dihedral[0 ], &dihedral[2 ], pos, deriv );
214+ energy = old_dihedral_energy ( &dihedral[0 ], &dihedral[1 ], pos, deriv );
160215 }
161216 }
162217 std::cout << " Energy = " << energy << " \n " ;
0 commit comments