Skip to content

Commit 4558cf2

Browse files
committed
Merge remote-tracking branch 'origin/main'
2 parents 3fd43d2 + 845c7d5 commit 4558cf2

File tree

4 files changed

+243
-70
lines changed

4 files changed

+243
-70
lines changed

tlm3_dihedral/Dihed.cc

+201-67
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#include <iostream>
2+
#include <iomanip>
23
#include <cstdio>
34
#include <string>
45
#include <cmath>
56

7+
#define PI 3.1415926535
8+
69
#define TENM3 10e-3
710
#define MIN(XXX,YYY) ((XXX<YYY)?XXX:YYY)
811
#define MAX(XXX,YYY) ((XXX<YYY)?YYY:XXX)
@@ -22,6 +25,35 @@ void sinNPhiCosNPhi(int n, float* sinNPhi, float* cosNPhi,
2225
return;
2326
}
2427

28+
template <int N, typename Real>
29+
struct SinCos {
30+
static void sinNPhiCosNPhi(Real& sinNPhi, Real& cosNPhi,
31+
Real sinPhi, Real cosPhi ) {
32+
Real sinNm1Phi;
33+
Real cosNm1Phi;
34+
SinCos<N-1,Real>::sinNPhiCosNPhi(sinNm1Phi,cosNm1Phi,sinPhi,cosPhi);
35+
sinNPhi = cosPhi*sinNm1Phi+sinPhi*cosNm1Phi;
36+
cosNPhi = cosPhi*cosNm1Phi-sinPhi*sinNm1Phi;
37+
}
38+
};
39+
40+
template <>
41+
struct SinCos<1,float> {
42+
static void sinNPhiCosNPhi(float& sinNPhi, float& cosNPhi,
43+
float sinPhi, float cosPhi ) {
44+
sinNPhi = sinPhi;
45+
cosNPhi = cosPhi;
46+
}
47+
};
48+
49+
template <>
50+
struct SinCos<1,double> {
51+
static void sinNPhiCosNPhi(double& sinNPhi, double& cosNPhi,
52+
double sinPhi, double cosPhi ) {
53+
sinNPhi = sinPhi;
54+
cosNPhi = cosPhi;
55+
}
56+
};
2557

2658
int enzyme_dup;
2759
int enzyme_out;
@@ -44,9 +76,9 @@ struct Dihedral {
4476
extern double __enzyme_autodiff(void*, ...);
4577

4678
void dump_vec(const std::string& msg, float* deriv, size_t num) {
47-
std::cout << msg << " ";
79+
std::cout << std::setw(20) << msg << " ";
4880
for ( int i=0; i<num; i++ ) {
49-
std::cout << deriv[i] << " ";
81+
std::cout << std::setprecision(4) << std::fixed << std::setw(8) << deriv[i] << " ";
5082
}
5183
std::cout << "\n";
5284
}
@@ -55,19 +87,88 @@ void dump_vec(const std::string& msg, float* deriv, size_t num) {
5587
#define mysqrt(VAR) (std::sqrt(VAR))
5688
#define reciprocal(VAR) (1.0/VAR)
5789

90+
double old_simple_dihedral_energy_gradient( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float* pos, float* deriv ) {
91+
#define USE_EXPLICIT_DECLARES 1
92+
#define DECLARE_FLOAT(VAR) double VAR;
93+
#define Real double
94+
double EraseLinearDihedral;
95+
double x1, y1, z1;
96+
double x2, y2, z2;
97+
double x3, y3, z3;
98+
double x4, y4, z4;
99+
#include "_Dihedral_termDeclares.cc"
100+
#define DIHEDRAL_SET_PARAMETER(VAR) {}
101+
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
102+
#define DIHEDRAL_ENERGY_ACCUMULATE(VAR) { result += VAR; }
103+
#define DIHEDRAL_FORCE_ACCUMULATE(IDX,OFFSET,VAR) { deriv[IDX+OFFSET] -= VAR; }
104+
double result = 0.0;
105+
double SinNPhi;
106+
double CosNPhi;
107+
bool calcForce = true;
108+
#define DIHEDRAL_CALC_FORCE 1
109+
#include "_Dihedral_termCode.cc"
110+
#undef Real
111+
#undef DECLARE_FLOAT
112+
#undef DIHEDRAL_SET_PARAMETER
113+
#undef DIHEDRAL_SET_POSITION
114+
#undef DIHEDRAL_ENERGY_ACCUMULATE
115+
#undef DIHEDRAL_FORCE_ACCUMULATE
116+
return result;
117+
}
118+
119+
double simple_dihedral_energy( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float* pos ) {
120+
#define USE_EXPLICIT_DECLARES 1
121+
#define DECLARE_FLOAT(VAR) double VAR;
122+
#define Real double
123+
double EraseLinearDihedral;
124+
double x1, y1, z1;
125+
double x2, y2, z2;
126+
double x3, y3, z3;
127+
double x4, y4, z4;
128+
#include "_DihedralEnergy_termDeclares.cc"
129+
#define DIHEDRAL_SET_PARAMETER(VAR) {}
130+
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
131+
#define DIHEDRAL_ENERGY_ACCUMULATE(VAR) { result += VAR; }
132+
double result = 0.0;
133+
double SinNPhi;
134+
double CosNPhi;
135+
#include "_DihedralEnergy_termCode.cc"
136+
#undef Real
137+
#undef DECLARE_FLOAT
138+
#undef DIHEDRAL_SET_PARAMETER
139+
#undef DIHEDRAL_SET_POSITION
140+
#undef DIHEDRAL_ENERGY_ACCUMULATE
141+
return result;
142+
}
143+
144+
void simple_dihedral_gradient( double sinPhase, double cosPhase, double V, double DN, int IN, int I1, int I2, int I3, int I4, float* pos, float* grad ) {
145+
__enzyme_autodiff( (void*)simple_dihedral_energy,
146+
enzyme_const, sinPhase,
147+
enzyme_const, cosPhase,
148+
enzyme_const, V,
149+
enzyme_const, DN,
150+
enzyme_const, IN,
151+
enzyme_const, I1,
152+
enzyme_const, I2,
153+
enzyme_const, I3,
154+
enzyme_const, I4,
155+
enzyme_dup, pos, grad );
156+
}
157+
158+
template <typename Real>
58159
float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos) {
59160
#define USE_EXPLICIT_DECLARES 1
60-
#define DECLARE_FLOAT(VAR) float VAR;
61-
float EraseLinearDihedral;
62-
float sinPhase;
63-
float cosPhase;
64-
float V;
65-
float DN;
66-
float IN;
67-
float x1, y1, z1;
68-
float x2, y2, z2;
69-
float x3, y3, z3;
70-
float x4, y4, z4;
161+
#define DECLARE_FLOAT(VAR) Real VAR;
162+
Real EraseLinearDihedral;
163+
Real sinPhase;
164+
Real cosPhase;
165+
Real V;
166+
Real DN;
167+
int IN;
168+
Real x1, y1, z1;
169+
Real x2, y2, z2;
170+
Real x3, y3, z3;
171+
Real x4, y4, z4;
71172
#include "_DihedralEnergy_termDeclares.cc"
72173
#define DIHEDRAL_SET_PARAMETER(VAR) { VAR = dihedral->VAR; }
73174
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
@@ -76,9 +177,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
76177
int I2;
77178
int I3;
78179
int I4;
79-
float result = 0.0;
80-
float SinNPhi;
81-
float CosNPhi;
180+
Real result = 0.0;
181+
Real SinNPhi;
182+
Real CosNPhi;
82183
for ( auto dihedral = dihedral_begin; dihedral < dihedral_end; dihedral++ ) {
83184
#include "_DihedralEnergy_termCode.cc"
84185
}
@@ -89,8 +190,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
89190
return result;
90191
}
91192

92-
void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv ) {
93-
__enzyme_autodiff( (void*)dihedral_energy,
193+
template <typename Real>
194+
__attribute__((noinline)) void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv ) {
195+
__enzyme_autodiff( (void*)dihedral_energy<Real>,
94196
enzyme_const, dihedral_begin,
95197
enzyme_const, dihedral_end,
96198
enzyme_dup, pos, deriv );
@@ -99,7 +201,7 @@ void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, floa
99201
void finite_diff_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv)
100202
{
101203
// Calculate the energy using the dihedral_energy function
102-
float energy_old = dihedral_energy(dihedral_begin, dihedral_end, pos);
204+
float energy_old = dihedral_energy<float>(dihedral_begin, dihedral_end, pos);
103205

104206
// Iterate over each position in the pos array
105207

@@ -115,41 +217,42 @@ void finite_diff_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_en
115217
posp[i] = pos[i] + TENM3;
116218
posm[i] = pos[i] - TENM3;
117219

118-
float energy_new_p = dihedral_energy(dihedral_begin, dihedral_end, posp);
119-
float energy_new_m = dihedral_energy(dihedral_begin, dihedral_end, posm);
220+
float energy_new_p = dihedral_energy<float>(dihedral_begin, dihedral_end, posp);
221+
float energy_new_m = dihedral_energy<float>(dihedral_begin, dihedral_end, posm);
120222

121223
// Calculate the derivative using finite differences
122224
deriv[i] = (energy_new_p - energy_new_m) / (2.0*TENM3);
123225
}
124226
}
125227

126-
float old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float* pos, float* deriv) {
228+
template <typename Real>
229+
Real old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float* pos, float* deriv) {
127230
#define USE_EXPLICIT_DECLARES 1
128-
#define DECLARE_FLOAT(VAR) float VAR;
129-
float sinPhase;
130-
float cosPhase;
131-
float V;
132-
float DN;
133-
float IN;
134-
float x1, y1, z1;
135-
float x2, y2, z2;
136-
float x3, y3, z3;
137-
float x4, y4, z4;
231+
#define DECLARE_FLOAT(VAR) Real VAR;
232+
Real sinPhase;
233+
Real cosPhase;
234+
Real V;
235+
Real DN;
236+
int IN;
237+
Real x1, y1, z1;
238+
Real x2, y2, z2;
239+
Real x3, y3, z3;
240+
Real x4, y4, z4;
138241
#include "_Dihedral_termDeclares.cc"
139242
#define DIHEDRAL_SET_PARAMETER(VAR) { VAR = dihedral->VAR; }
140243
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
141244
#define DIHEDRAL_ENERGY_ACCUMULATE(VAR) { result += VAR; }
142-
#define DIHEDRAL_FORCE_ACCUMULATE(IDX,OFFSET,VAR) { deriv[IDX+OFFSET] += VAR; }
245+
#define DIHEDRAL_FORCE_ACCUMULATE(IDX,OFFSET,VAR) { deriv[IDX+OFFSET] -= VAR; }
143246
int I1;
144247
int I2;
145248
int I3;
146249
int I4;
147-
float SinNPhi;
148-
float CosNPhi;
149-
float result = 0.0;
250+
Real SinNPhi;
251+
Real CosNPhi;
252+
Real result = 0.0;
150253
bool calcForce = true;
151254
#define DIHEDRAL_CALC_FORCE 1
152-
float EraseLinearDihedral;
255+
Real EraseLinearDihedral;
153256
for ( auto dihedral = dihedral_begin; dihedral<dihedral_end; dihedral++ ) {
154257
#include "_Dihedral_termCode.cc"
155258
}
@@ -165,56 +268,87 @@ void zeroVec(float* deriv, size_t num) {
165268
deriv[i] = 0.0;
166269
}
167270
}
271+
void copyVec(float* result, float* val, size_t num) {
272+
for (int i=0; i<num; i++ ) {
273+
result[i] = val[i];
274+
}
275+
}
276+
277+
168278
int main( int argc, const char* argv[] ) {
169279
float ANG = 20.0*0.0174533;
170280
float pos[12] = {0.0, 0.0, 1.0,
171-
0.0, 0.0, 0.0,
172-
1.0, 0.0, 0.0,
173-
1.0, (float)-sin(ANG), (float)cos(ANG) };
281+
0.0, 0.0, 0.0,
282+
1.0, 0.0, 0.0,
283+
1.0, (float)-sin(ANG), (float)cos(ANG) };
174284
float deriv[12];
175285
Dihedral dihedral[] = { {0.0, 1.0, 10.0, 2.0, 2, 0, 3, 6, 9 } };
176286

177287
dump_vec("pos", pos, 12);
178288
float energy = 0.0;
179289
std::string arg1(argv[1]);
180-
int donew = 0;
181-
182-
if(arg1 == "new")
183-
donew = 1;
184-
else if(arg1 == "fin")
185-
donew = 2;
186-
else
187-
donew = 3;
188290

189291
size_t num = atoi(argv[2]);
190-
if (donew == 1)// for new
191-
{
292+
if (arg1 == "new-float") {
192293
std::cout << "New method" << "\n";
193294
// energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
194295
for ( size_t nn = 0; nn<num; nn++ ) {
195296
zeroVec(deriv,12);
196-
grad_dihedral_energy( &dihedral[0], &dihedral[1], pos, deriv);
297+
grad_dihedral_energy<float>( &dihedral[0], &dihedral[1], pos, deriv);
197298
}
198-
}
199-
200-
else if (donew == 2)// for finite diff
201-
{
202-
for ( size_t nn = 0; nn<num; nn++ ) {
203-
zeroVec(deriv,12);
204-
finite_diff_dihedral_energy(&dihedral[0], &dihedral[1], pos, deriv);
205-
}
299+
dump_vec("new float deriv", deriv, 12);
300+
} else if (arg1 == "new-double") {
301+
std::cout << "New method" << "\n";
302+
// energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
303+
for ( size_t nn = 0; nn<num; nn++ ) {
304+
zeroVec(deriv,12);
305+
grad_dihedral_energy<double>( &dihedral[0], &dihedral[1], pos, deriv);
306+
}
307+
dump_vec("new double deriv", deriv, 12);
308+
} else if ( arg1 == "fin") {
309+
for ( size_t nn = 0; nn<num; nn++ ) {
310+
zeroVec(deriv,12);
311+
finite_diff_dihedral_energy(&dihedral[0], &dihedral[1], pos, deriv);
206312
}
207-
208-
209-
else//for old
210-
{
313+
dump_vec("fin deriv", deriv, 12);
314+
} else if (arg1 == "old-float") {
315+
float fenergy;
316+
float fderiv[12];
317+
float fpos[12];
318+
copyVec(fpos,pos,12);
211319
std::cout << "Old method" << "\n";
212320
for ( size_t nn = 0; nn<num; nn++ ) {
213-
zeroVec(deriv,12);
214-
energy = old_dihedral_energy( &dihedral[0], &dihedral[1], pos, deriv );
321+
zeroVec(fderiv,12);
322+
energy = old_dihedral_energy<float>( &dihedral[0], &dihedral[1], fpos, fderiv );
323+
}
324+
std::cout << "Energy = " << fenergy << "\n";
325+
dump_vec("deriv", fderiv, 12);
326+
} else if (arg1 == "old-double") {
327+
float fenergy;
328+
float fderiv[12];
329+
float fpos[12];
330+
copyVec(fpos,pos,12);
331+
std::cout << "Old method" << "\n";
332+
for ( size_t nn = 0; nn<num; nn++ ) {
333+
zeroVec(fderiv,12);
334+
energy = old_dihedral_energy<double>( &dihedral[0], &dihedral[1], fpos, fderiv );
335+
}
336+
std::cout << "Energy = " << fenergy << "\n";
337+
dump_vec("deriv", fderiv, 12);
338+
} else if (arg1 == "new-simple") {
339+
float fderiv[12];
340+
for ( size_t nn = 0; nn<num; nn++ ) {
341+
zeroVec(fderiv,12);
342+
simple_dihedral_gradient( 0.0, 1.0, 10.0, 2.0, 2, 0, 3, 6, 9, pos, fderiv );
343+
}
344+
dump_vec("simple deriv via Enzyme", fderiv, 12 );
345+
} else if (arg1 == "old-simple") {
346+
float fderiv[12];
347+
for ( size_t nn = 0; nn<num; nn++ ) {
348+
zeroVec(fderiv,12);
349+
old_simple_dihedral_energy_gradient( 0.0, 1.0, 10.0, 2.0, 2, 0, 3, 6, 9, pos, fderiv );
215350
}
351+
dump_vec("simple deriv old code ", fderiv, 12 );
216352
}
217-
std::cout << "Energy = " << energy << "\n";
218-
dump_vec("deriv", deriv, 12);
219353

220354
}

0 commit comments

Comments
 (0)