Skip to content

Commit da111ba

Browse files
committed
added finite deriv. calculations
1 parent 1ec8353 commit da111ba

File tree

1 file changed

+68
-13
lines changed

1 file changed

+68
-13
lines changed

tlm3_dihedral/Dihed.cc

+68-13
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
#define MIN(XXX,YYY) ((XXX<YYY)?XXX:YYY)
88
#define MAX(XXX,YYY) ((XXX<YYY)?YYY:XXX)
99

10-
void sinNPhiCosNPhi(int n, double* sinNPhi, double* cosNPhi,
11-
double sinPhi, double cosPhi )
10+
void sinNPhiCosNPhi(int n, float* sinNPhi, float* cosNPhi,
11+
float sinPhi, float cosPhi )
1212
{
13-
double sinNm1Phi, cosNm1Phi;
13+
float sinNm1Phi, cosNm1Phi;
1414
if ( n==1 ) {
1515
*sinNPhi = sinPhi;
1616
*cosNPhi = cosPhi;
@@ -32,7 +32,7 @@ struct Dihedral {
3232
float cosPhase;
3333
float V;
3434
float DN;
35-
float IN;
35+
int IN;
3636
int I1;
3737
int I2;
3838
int I3;
@@ -90,12 +90,39 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
9090
}
9191

9292
void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv ) {
93-
__enzyme_autodiff( (void*)stretch_energy,
93+
__enzyme_autodiff( (void*)dihedral_energy,
9494
enzyme_const, dihedral_begin,
9595
enzyme_const, dihedral_end,
9696
enzyme_dup, pos, deriv );
9797
}
9898

99+
void finite_diff_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv)
100+
{
101+
// Calculate the energy using the dihedral_energy function
102+
float energy_old = dihedral_energy(dihedral_begin, dihedral_end, pos);
103+
104+
// Iterate over each position in the pos array
105+
106+
for (int i = 0; i < 12; i++)
107+
{
108+
float posp[12];
109+
float posm[12];
110+
for(int j = 0; j<12; j++)
111+
{
112+
posp[j] = pos[j];
113+
posm[j] = pos[j];
114+
}
115+
posp[i] = pos[i] + TENM3;
116+
posm[i] = pos[i] - TENM3;
117+
118+
float energy_new_p = dihedral_energy(dihedral_begin, dihedral_end, posp);
119+
float energy_new_m = dihedral_energy(dihedral_begin, dihedral_end, posm);
120+
121+
// Calculate the derivative using finite differences
122+
deriv[i] = (energy_new_p - energy_new_m) / (2.0*TENM3);
123+
}
124+
}
125+
99126
float old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float* pos, float* deriv) {
100127
#define USE_EXPLICIT_DECLARES 1
101128
#define DECLARE_FLOAT(VAR) float VAR;
@@ -117,9 +144,12 @@ float old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float
117144
int I2;
118145
int I3;
119146
int I4;
147+
float SinNPhi;
148+
float CosNPhi;
120149
float result = 0.0;
121150
bool calcForce = true;
122151
#define DIHEDRAL_CALC_FORCE 1
152+
float EraseLinearDihedral;
123153
for ( auto dihedral = dihedral_begin; dihedral<dihedral_end; dihedral++ ) {
124154
#include "_Dihedral_termCode.cc"
125155
}
@@ -136,27 +166,52 @@ void zeroVec(float* deriv, size_t num) {
136166
}
137167
}
138168
int main( int argc, const char* argv[] ) {
139-
float pos[12] = {0.0, 19.0, 3.0, 10.0, 7.0, 80.0,
140-
20.0, 15.0, 17.0, 25.0, 44.0, 23.0 };
169+
float ANG = 20.0*0.0174533;
170+
float pos[12] = {0.0, 0.0, 1.0,
171+
0.0, 0.0, 0.0,
172+
1.0, 0.0, 0.0,
173+
1.0, (float)-sin(ANG), (float)cos(ANG) };
141174
float deriv[12];
142-
Stretch stretch[] = { {10.0, 2.0, 0, 3}, {20.0, 3.0, 6, 9} };
175+
Dihedral dihedral[] = { {0.0, 1.0, 10.0, 2.0, 2, 0, 3, 6, 9 } };
143176

144177
dump_vec("pos", pos, 12);
145178
float energy = 0.0;
146179
std::string arg1(argv[1]);
147-
bool donew = (arg1 == "new");
180+
int donew = 0;
181+
182+
if(arg1 == "new")
183+
donew = 1;
184+
else if(arg1 == "fin")
185+
donew = 2;
186+
else
187+
donew = 3;
188+
148189
size_t num = atoi(argv[2]);
149-
if (donew) {
190+
if (donew == 1)// for new
191+
{
150192
std::cout << "New method" << "\n";
193+
// energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
151194
for ( size_t nn = 0; nn<num; nn++ ) {
152195
zeroVec(deriv,12);
153-
grad_dihedral_energy( &dihedral[0], &dihedral[2], pos, deriv);
196+
grad_dihedral_energy( &dihedral[0], &dihedral[1], pos, deriv);
154197
}
155-
} else {
198+
}
199+
200+
else if (donew == 2)// for finite diff
201+
{
202+
for ( size_t nn = 0; nn<num; nn++ ) {
203+
zeroVec(deriv,12);
204+
finite_diff_dihedral_energy(&dihedral[0], &dihedral[1], pos, deriv);
205+
}
206+
}
207+
208+
209+
else//for old
210+
{
156211
std::cout << "Old method" << "\n";
157212
for ( size_t nn = 0; nn<num; nn++ ) {
158213
zeroVec(deriv,12);
159-
energy = old_dihedral_energy( &dihedral[0], &dihedral[2], pos, deriv );
214+
energy = old_dihedral_energy( &dihedral[0], &dihedral[1], pos, deriv );
160215
}
161216
}
162217
std::cout << "Energy = " << energy << "\n";

0 commit comments

Comments
 (0)