Skip to content

Commit 402dc3e

Browse files
author
meister
committed
Added more options
1 parent 4aa87d7 commit 402dc3e

File tree

4 files changed

+112
-84
lines changed

4 files changed

+112
-84
lines changed

tlm3_dihedral/Dihed.cc

Lines changed: 99 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#include <iostream>
2+
#include <iomanip>
23
#include <cstdio>
34
#include <string>
45
#include <cmath>
56

7+
#define PI 3.1415926535
8+
69
#define TENM3 10e-3
710
#define MIN(XXX,YYY) ((XXX<YYY)?XXX:YYY)
811
#define MAX(XXX,YYY) ((XXX<YYY)?YYY:XXX)
@@ -22,27 +25,36 @@ void sinNPhiCosNPhi(int n, float* sinNPhi, float* cosNPhi,
2225
return;
2326
}
2427

25-
template <int N>
28+
template <int N, typename Real>
2629
struct SinCos {
27-
static void sinNPhiCosNPhi(float& sinNPhi, float& cosNPhi,
28-
float sinPhi, float cosPhi ) {
29-
float sinNm1Phi;
30-
float cosNm1Phi;
31-
SinCos<N-1>::sinNPhiCosNPhi(sinNm1Phi,cosNm1Phi,sinPhi,cosPhi);
30+
static void sinNPhiCosNPhi(Real& sinNPhi, Real& cosNPhi,
31+
Real sinPhi, Real cosPhi ) {
32+
Real sinNm1Phi;
33+
Real cosNm1Phi;
34+
SinCos<N-1,Real>::sinNPhiCosNPhi(sinNm1Phi,cosNm1Phi,sinPhi,cosPhi);
3235
sinNPhi = cosPhi*sinNm1Phi+sinPhi*cosNm1Phi;
3336
cosNPhi = cosPhi*cosNm1Phi-sinPhi*sinNm1Phi;
3437
}
3538
};
3639

3740
template <>
38-
struct SinCos<1> {
41+
struct SinCos<1,float> {
3942
static void sinNPhiCosNPhi(float& sinNPhi, float& cosNPhi,
4043
float sinPhi, float cosPhi ) {
4144
sinNPhi = sinPhi;
4245
cosNPhi = cosPhi;
4346
}
4447
};
4548

49+
template <>
50+
struct SinCos<1,double> {
51+
static void sinNPhiCosNPhi(double& sinNPhi, double& cosNPhi,
52+
double sinPhi, double cosPhi ) {
53+
sinNPhi = sinPhi;
54+
cosNPhi = cosPhi;
55+
}
56+
};
57+
4658
int enzyme_dup;
4759
int enzyme_out;
4860
int enzyme_const;
@@ -64,9 +76,9 @@ struct Dihedral {
6476
extern double __enzyme_autodiff(void*, ...);
6577

6678
void dump_vec(const std::string& msg, float* deriv, size_t num) {
67-
std::cout << msg << " ";
79+
std::cout << std::setw(10) << msg << " ";
6880
for ( int i=0; i<num; i++ ) {
69-
std::cout << deriv[i] << " ";
81+
std::cout << std::setprecision(4) << std::fixed << std::setw(8) << deriv[i] << " ";
7082
}
7183
std::cout << "\n";
7284
}
@@ -75,19 +87,20 @@ void dump_vec(const std::string& msg, float* deriv, size_t num) {
7587
#define mysqrt(VAR) (std::sqrt(VAR))
7688
#define reciprocal(VAR) (1.0/VAR)
7789

90+
template <typename Real>
7891
float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos) {
7992
#define USE_EXPLICIT_DECLARES 1
80-
#define DECLARE_FLOAT(VAR) float VAR;
81-
float EraseLinearDihedral;
82-
float sinPhase;
83-
float cosPhase;
84-
float V;
85-
float DN;
93+
#define DECLARE_FLOAT(VAR) Real VAR;
94+
Real EraseLinearDihedral;
95+
Real sinPhase;
96+
Real cosPhase;
97+
Real V;
98+
Real DN;
8699
int IN;
87-
float x1, y1, z1;
88-
float x2, y2, z2;
89-
float x3, y3, z3;
90-
float x4, y4, z4;
100+
Real x1, y1, z1;
101+
Real x2, y2, z2;
102+
Real x3, y3, z3;
103+
Real x4, y4, z4;
91104
#include "_DihedralEnergy_termDeclares.cc"
92105
#define DIHEDRAL_SET_PARAMETER(VAR) { VAR = dihedral->VAR; }
93106
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
@@ -96,9 +109,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
96109
int I2;
97110
int I3;
98111
int I4;
99-
float result = 0.0;
100-
float SinNPhi;
101-
float CosNPhi;
112+
Real result = 0.0;
113+
Real SinNPhi;
114+
Real CosNPhi;
102115
for ( auto dihedral = dihedral_begin; dihedral < dihedral_end; dihedral++ ) {
103116
#include "_DihedralEnergy_termCode.cc"
104117
}
@@ -109,8 +122,9 @@ float dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* p
109122
return result;
110123
}
111124

125+
template <typename Real>
112126
void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv ) {
113-
__enzyme_autodiff( (void*)dihedral_energy,
127+
__enzyme_autodiff( (void*)dihedral_energy<Real>,
114128
enzyme_const, dihedral_begin,
115129
enzyme_const, dihedral_end,
116130
enzyme_dup, pos, deriv );
@@ -119,7 +133,7 @@ void grad_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, floa
119133
void finite_diff_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_end, float* pos, float* deriv)
120134
{
121135
// Calculate the energy using the dihedral_energy function
122-
float energy_old = dihedral_energy(dihedral_begin, dihedral_end, pos);
136+
float energy_old = dihedral_energy<float>(dihedral_begin, dihedral_end, pos);
123137

124138
// Iterate over each position in the pos array
125139

@@ -135,26 +149,27 @@ void finite_diff_dihedral_energy(Dihedral* dihedral_begin, Dihedral* dihedral_en
135149
posp[i] = pos[i] + TENM3;
136150
posm[i] = pos[i] - TENM3;
137151

138-
float energy_new_p = dihedral_energy(dihedral_begin, dihedral_end, posp);
139-
float energy_new_m = dihedral_energy(dihedral_begin, dihedral_end, posm);
152+
float energy_new_p = dihedral_energy<float>(dihedral_begin, dihedral_end, posp);
153+
float energy_new_m = dihedral_energy<float>(dihedral_begin, dihedral_end, posm);
140154

141155
// Calculate the derivative using finite differences
142156
deriv[i] = (energy_new_p - energy_new_m) / (2.0*TENM3);
143157
}
144158
}
145159

146-
float old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float* pos, float* deriv) {
160+
template <typename Real>
161+
Real old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float* pos, float* deriv) {
147162
#define USE_EXPLICIT_DECLARES 1
148-
#define DECLARE_FLOAT(VAR) float VAR;
149-
float sinPhase;
150-
float cosPhase;
151-
float V;
152-
float DN;
163+
#define DECLARE_FLOAT(VAR) Real VAR;
164+
Real sinPhase;
165+
Real cosPhase;
166+
Real V;
167+
Real DN;
153168
int IN;
154-
float x1, y1, z1;
155-
float x2, y2, z2;
156-
float x3, y3, z3;
157-
float x4, y4, z4;
169+
Real x1, y1, z1;
170+
Real x2, y2, z2;
171+
Real x3, y3, z3;
172+
Real x4, y4, z4;
158173
#include "_Dihedral_termDeclares.cc"
159174
#define DIHEDRAL_SET_PARAMETER(VAR) { VAR = dihedral->VAR; }
160175
#define DIHEDRAL_SET_POSITION(VAR,IDX,OFFSET) { VAR = pos[IDX+OFFSET]; }
@@ -164,12 +179,12 @@ float old_dihedral_energy(Dihedral* dihedral_begin, Dihedral*dihedral_end, float
164179
int I2;
165180
int I3;
166181
int I4;
167-
float SinNPhi;
168-
float CosNPhi;
169-
float result = 0.0;
182+
Real SinNPhi;
183+
Real CosNPhi;
184+
Real result = 0.0;
170185
bool calcForce = true;
171186
#define DIHEDRAL_CALC_FORCE 1
172-
float EraseLinearDihedral;
187+
Real EraseLinearDihedral;
173188
for ( auto dihedral = dihedral_begin; dihedral<dihedral_end; dihedral++ ) {
174189
#include "_Dihedral_termCode.cc"
175190
}
@@ -185,56 +200,69 @@ void zeroVec(float* deriv, size_t num) {
185200
deriv[i] = 0.0;
186201
}
187202
}
203+
void copyVec(float* result, float* val, size_t num) {
204+
for (int i=0; i<num; i++ ) {
205+
result[i] = val[i];
206+
}
207+
}
208+
209+
188210
int main( int argc, const char* argv[] ) {
189211
float ANG = 20.0*0.0174533;
190212
float pos[12] = {0.0, 0.0, 1.0,
191-
0.0, 0.0, 0.0,
192-
1.0, 0.0, 0.0,
193-
1.0, (float)-sin(ANG), (float)cos(ANG) };
213+
0.0, 0.0, 0.0,
214+
1.0, 0.0, 0.0,
215+
1.0, (float)-sin(ANG), (float)cos(ANG) };
194216
float deriv[12];
195217
Dihedral dihedral[] = { {0.0, 1.0, 10.0, 2.0, 2, 0, 3, 6, 9 } };
196218

197219
dump_vec("pos", pos, 12);
198220
float energy = 0.0;
199221
std::string arg1(argv[1]);
200-
int donew = 0;
201-
202-
if(arg1 == "new")
203-
donew = 1;
204-
else if(arg1 == "fin")
205-
donew = 2;
206-
else
207-
donew = 3;
208222

209223
size_t num = atoi(argv[2]);
210-
if (donew == 1)// for new
211-
{
224+
if (arg1 == "new-float") {
212225
std::cout << "New method" << "\n";
213226
// energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
214227
for ( size_t nn = 0; nn<num; nn++ ) {
215228
zeroVec(deriv,12);
216-
grad_dihedral_energy( &dihedral[0], &dihedral[1], pos, deriv);
229+
grad_dihedral_energy<float>( &dihedral[0], &dihedral[1], pos, deriv);
217230
}
218-
}
219-
220-
else if (donew == 2)// for finite diff
221-
{
222-
for ( size_t nn = 0; nn<num; nn++ ) {
223-
zeroVec(deriv,12);
224-
finite_diff_dihedral_energy(&dihedral[0], &dihedral[1], pos, deriv);
225-
}
231+
} else if (arg1 == "new-double") {
232+
std::cout << "New method" << "\n";
233+
// energy = dihedral_energy( &dihedral[0], &dihedral[1], pos );
234+
for ( size_t nn = 0; nn<num; nn++ ) {
235+
zeroVec(deriv,12);
236+
grad_dihedral_energy<double>( &dihedral[0], &dihedral[1], pos, deriv);
237+
}
238+
} else if ( arg1 == "fin") {
239+
for ( size_t nn = 0; nn<num; nn++ ) {
240+
zeroVec(deriv,12);
241+
finite_diff_dihedral_energy(&dihedral[0], &dihedral[1], pos, deriv);
226242
}
227-
228-
229-
else//for old
230-
{
243+
} else if (arg1 == "old-float") {
244+
float fenergy;
245+
float fderiv[12];
246+
float fpos[12];
247+
copyVec(fpos,pos,12);
231248
std::cout << "Old method" << "\n";
232249
for ( size_t nn = 0; nn<num; nn++ ) {
233-
zeroVec(deriv,12);
234-
energy = old_dihedral_energy( &dihedral[0], &dihedral[1], pos, deriv );
250+
zeroVec(fderiv,12);
251+
energy = old_dihedral_energy<float>( &dihedral[0], &dihedral[1], fpos, fderiv );
235252
}
253+
std::cout << "Energy = " << fenergy << "\n";
254+
dump_vec("deriv", fderiv, 12);
255+
} else if (arg1 == "old-double") {
256+
float fenergy;
257+
float fderiv[12];
258+
float fpos[12];
259+
copyVec(fpos,pos,12);
260+
std::cout << "Old method" << "\n";
261+
for ( size_t nn = 0; nn<num; nn++ ) {
262+
zeroVec(fderiv,12);
263+
energy = old_dihedral_energy<double>( &dihedral[0], &dihedral[1], fpos, fderiv );
264+
}
265+
std::cout << "Energy = " << fenergy << "\n";
266+
dump_vec("deriv", fderiv, 12);
236267
}
237-
std::cout << "Energy = " << energy << "\n";
238-
dump_vec("deriv", deriv, 12);
239-
240268
}

tlm3_dihedral/_DihedralEnergy_termCode.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,22 @@
102102
/*SinNPhi = mathSinNPhi[IN,SinPhi,CosPhi];*/
103103
switch(IN) {
104104
case 1:
105-
SinCos<1>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
105+
SinCos<1,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
106106
break;
107107
case 2:
108-
SinCos<2>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
108+
SinCos<2,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
109109
break;
110110
case 3:
111-
SinCos<3>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
111+
SinCos<3,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
112112
break;
113113
case 4:
114-
SinCos<4>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
114+
SinCos<4,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
115115
break;
116116
case 5:
117-
SinCos<5>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
117+
SinCos<5,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
118118
break;
119119
case 6:
120-
SinCos<6>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
120+
SinCos<6,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
121121
break;
122122
};
123123
tx135 = CosNPhi*cosPhase; /* rule 103 */

tlm3_dihedral/_Dihedral_termCode.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,22 @@
102102
/*SinNPhi = mathSinNPhi[IN,SinPhi,CosPhi];*/
103103
switch(IN) {
104104
case 1:
105-
SinCos<1>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
105+
SinCos<1,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
106106
break;
107107
case 2:
108-
SinCos<2>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
108+
SinCos<2,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
109109
break;
110110
case 3:
111-
SinCos<3>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
111+
SinCos<3,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
112112
break;
113113
case 4:
114-
SinCos<4>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
114+
SinCos<4,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
115115
break;
116116
case 5:
117-
SinCos<5>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
117+
SinCos<5,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
118118
break;
119119
case 6:
120-
SinCos<6>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
120+
SinCos<6,Real>::sinNPhiCosNPhi(SinNPhi,CosNPhi,SinPhi,CosPhi);
121121
break;
122122
};
123123
tx973 = CosNPhi*cosPhase; /* rule 103 */

tlm3_dihedral/makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ all-Dihed: cmp-Dihed ad-Dihed opt-Dihed lower-Dihed
66
echo done
77

88
cmp-Dihed:
9-
clangi-15 Dihed.cc -S -emit-llvm -o Dihed-cmp.ll -O2 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops
9+
clang-15 Dihed.cc -S -emit-llvm -o Dihed-cmp.ll -O2 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -ftemplate-depth=20
1010

1111
ad-Dihed:
1212
opt-15 Dihed-cmp.ll -enable-new-pm=0 -load=$(ENZYME_so) --enzyme -o Dihed-ad.ll -S

0 commit comments

Comments
 (0)