Skip to content

Commit 55bc5e9

Browse files
committed
Make CFbend fast
1 parent 2eb9384 commit 55bc5e9

File tree

1 file changed

+96
-49
lines changed

1 file changed

+96
-49
lines changed

src/elements/CFbend.H

Lines changed: 96 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,56 @@ namespace impactx::elements
8181
/** Push all particles */
8282
using BeamOptic::operator();
8383

84+
/** Compute and cache the constants for the push.
85+
*
86+
* In particular, used to pre-compute and cache variables that are
87+
* independent of the individually tracked particle.
88+
*
89+
* @param refpart reference particle
90+
*/
91+
void compute_constants (RefPart const & refpart)
92+
{
93+
using namespace amrex::literals; // for _rt and _prt
94+
95+
Alignment::compute_constants(refpart);
96+
97+
// length of the current slice
98+
m_slice_ds = m_ds / nslice();
99+
100+
// find beta*gamma^2, beta
101+
amrex::ParticleReal const betgam2 = std::pow(refpart.pt, 2) - 1.0_prt;
102+
amrex::ParticleReal const bet = std::sqrt(betgam2 / (1.0_prt + betgam2));
103+
m_ibetgam2 = 1.0_prt / betgam2;
104+
amrex::ParticleReal const b2rc2 = std::pow(bet, 2) * std::pow(m_rc, 2);
105+
m_igobr = 1.0_prt / ( m_gx * m_omega_x * b2rc2 );
106+
107+
// update horizontal and longitudinal phase space variables
108+
m_gx = m_k + std::pow(m_rc,-2);
109+
m_omega_x = std::sqrt(std::abs(m_gx));
110+
111+
// update vertical phase space variables
112+
m_gy = -m_k;
113+
m_omega_y = std::sqrt(std::abs(m_gy));
114+
115+
// trigonometry
116+
auto const [sinx, cosx] = amrex::Math::sincos(m_omega_x * m_slice_ds);
117+
m_sinx = sinx;
118+
m_cosx = cosx;
119+
m_sinhx = std::sinh(m_omega_x * m_slice_ds);
120+
m_coshx = std::cosh(m_omega_x * m_slice_ds);
121+
auto const [siny, cosy] = amrex::Math::sincos(m_omega_y * m_slice_ds);
122+
m_siny = siny;
123+
m_cosy = cosy;
124+
m_sinhy = std::sinh(m_omega_y * m_slice_ds);
125+
m_coshy = std::cosh(m_omega_y * m_slice_ds);
126+
127+
m_rgbrc = 1.0_prt / ( m_gx * bet * m_rc );
128+
m_robrc = m_omega_x * bet * m_rc;
129+
}
130+
84131
/** This is a cfbend functor, so that a variable of this type can be used like a cfbend function.
132+
*
133+
* The @see compute_constants method must be called before pushing particles through this operator.
85134
*
86135
* @param x particle position in x
87136
* @param y particle position in y
@@ -90,7 +139,7 @@ namespace impactx::elements
90139
* @param py particle momentum in y
91140
* @param pt particle momentum in t
92141
* @param idcpu particle global index
93-
* @param refpart reference particle
142+
* @param refpart reference particle (unused)
94143
*/
95144
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
96145
void operator() (
@@ -101,7 +150,7 @@ namespace impactx::elements
101150
amrex::ParticleReal & AMREX_RESTRICT py,
102151
amrex::ParticleReal & AMREX_RESTRICT pt,
103152
uint64_t & AMREX_RESTRICT idcpu,
104-
RefPart const & refpart
153+
[[maybe_unused]] RefPart const & AMREX_RESTRICT refpart
105154
) const
106155
{
107156
using namespace amrex::literals; // for _rt and _prt
@@ -119,67 +168,44 @@ namespace impactx::elements
119168
amrex::ParticleReal pyout = py;
120169
amrex::ParticleReal ptout = pt;
121170

122-
// length of the current slice
123-
amrex::ParticleReal const slice_ds = m_ds / nslice();
124-
125-
// access reference particle values to find beta*gamma^2
126-
amrex::ParticleReal const pt_ref = refpart.pt;
127-
amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt;
128-
amrex::ParticleReal const bet = std::sqrt(betgam2/(1.0_prt + betgam2));
129-
130171
// update horizontal and longitudinal phase space variables
131-
amrex::ParticleReal const gx = m_k + std::pow(m_rc,-2);
132-
amrex::ParticleReal const omegax = std::sqrt(std::abs(gx));
133-
134-
if(gx > 0.0) {
135-
// calculate expensive terms once
136-
auto const [sinx, cosx] = amrex::Math::sincos(omegax * slice_ds);
137-
amrex::ParticleReal const r56 = slice_ds/betgam2
138-
+ (sinx - omegax*slice_ds)/(gx*omegax * std::pow(bet,2) * std::pow(m_rc,2));
172+
if (m_gx > 0.0)
173+
{
174+
amrex::ParticleReal const r56 = m_slice_ds * m_ibetgam2
175+
+ (m_sinx - m_omega_x * m_slice_ds) * m_igobr;
139176

140177
// advance position and momentum (focusing)
141-
x = cosx*xout + sinx/omegax*px - (1.0_prt - cosx)/(gx*bet*m_rc)*pt;
142-
pxout = -omegax*sinx*xout + cosx*px - sinx/(omegax*bet*m_rc)*pt;
178+
x = m_cosx * xout + m_sinx / m_omega_x * px - (1.0_prt - m_cosx) * m_rgbrc * pt;
179+
pxout = -m_omega_x * m_sinx * xout + m_cosx * px - m_sinx * m_robrc * pt;
143180

144-
y = sinx/(omegax*bet*m_rc)*xout + (1.0_prt - cosx)/(gx*bet*m_rc)*px
145-
+ tout + r56*pt;
181+
y = m_sinx * m_robrc * xout + (1.0_prt - m_cosx) * m_rgbrc * px
182+
+ tout + r56 * pt;
146183
ptout = pt;
147-
} else {
148-
// calculate expensive terms once
149-
amrex::ParticleReal const sinhx = std::sinh(omegax * slice_ds);
150-
amrex::ParticleReal const coshx = std::cosh(omegax * slice_ds);
151-
amrex::ParticleReal const r56 = slice_ds/betgam2
152-
+ (sinhx - omegax*slice_ds)/(gx*omegax * std::pow(bet,2) * std::pow(m_rc,2));
184+
} else
185+
{
186+
amrex::ParticleReal const r56 = m_slice_ds * m_ibetgam2
187+
+ (m_sinhx - m_omega_x * m_slice_ds) * m_igobr;
153188

154189
// advance position and momentum (defocusing)
155-
x = coshx*xout + sinhx/omegax*px - (1.0_prt - coshx)/(gx*bet*m_rc)*pt;
156-
pxout = omegax*sinhx*xout + coshx*px - sinhx/(omegax*bet*m_rc)*pt;
190+
x = m_coshx * xout + m_sinhx / m_omega_x * px - (1.0_prt - m_coshx) * m_rgbrc * pt;
191+
pxout = m_omega_x * m_sinhx * xout + m_coshx * px - m_sinhx * m_robrc * pt;
157192

158-
t = sinhx/(omegax*bet*m_rc)*xout + (1.0_prt - coshx)/(gx*bet*m_rc)*px
159-
+ tout + r56*pt;
193+
t = m_sinhx * m_robrc * xout + (1.0_prt - m_coshx) * m_rgbrc * px
194+
+ tout + r56 * pt;
160195
ptout = pt;
161196
}
162197

163198
// update vertical phase space variables
164-
amrex::ParticleReal const gy = -m_k;
165-
amrex::ParticleReal const omegay = std::sqrt(std::abs(gy));
166-
167-
if(gy > 0.0) {
168-
// calculate expensive terms once
169-
auto const [siny, cosy] = amrex::Math::sincos(omegay * slice_ds);
170-
199+
if (m_gy > 0.0)
200+
{
171201
// advance position and momentum (focusing)
172-
y = cosy*yout + siny/omegay*py;
173-
pyout = -omegay*siny*yout + cosy*py;
174-
175-
} else {
176-
// calculate expensive terms once
177-
amrex::ParticleReal const sinhy = std::sinh(omegay * slice_ds);
178-
amrex::ParticleReal const coshy = std::cosh(omegay * slice_ds);
179-
202+
y = m_cosy * yout + m_siny / m_omega_y * py;
203+
pyout = -m_omega_y * m_siny * yout + m_cosy * py;
204+
} else
205+
{
180206
// advance position and momentum (defocusing)
181-
y = coshy*yout + sinhy/omegay*py;
182-
pyout = omegay*sinhy*yout + coshy*py;
207+
y = m_coshy * yout + m_sinhy / m_omega_y * py;
208+
pyout = m_omega_y * m_sinhy * yout + m_coshy * py;
183209
}
184210

185211
// assign updated momenta
@@ -257,6 +283,27 @@ namespace impactx::elements
257283

258284
amrex::ParticleReal m_rc; //! bend radius in m
259285
amrex::ParticleReal m_k; //! quadrupole strength in m^(-2)
286+
287+
private:
288+
// constants that are independent of the individually tracked particle,
289+
// see: compute_constants() to refresh
290+
amrex::ParticleReal m_slice_ds; //! m_ds / nslice();
291+
amrex::ParticleReal m_ibetgam2; //! 1 / (beta*gamma^2)
292+
amrex::ParticleReal m_gx;
293+
amrex::ParticleReal m_gy;
294+
amrex::ParticleReal m_omega_x;
295+
amrex::ParticleReal m_omega_y;
296+
amrex::ParticleReal m_igobr;
297+
amrex::ParticleReal m_sinx; //! sin(omegax*x)
298+
amrex::ParticleReal m_cosx; //! cos(omegax*x)
299+
amrex::ParticleReal m_sinhx; //! sinh(omegax*x)
300+
amrex::ParticleReal m_coshx; //! cosh(omegax*x)
301+
amrex::ParticleReal m_siny; //! sin(omegay*y)
302+
amrex::ParticleReal m_cosy; //! cos(omegay*y)
303+
amrex::ParticleReal m_sinhy; //! sinh(omegay*y)
304+
amrex::ParticleReal m_coshy; //! cosh(omegay*y)
305+
amrex::ParticleReal m_rgbrc;
306+
amrex::ParticleReal m_robrc;
260307
};
261308

262309
} // namespace impactx

0 commit comments

Comments
 (0)