Skip to content

Commit

Permalink
Make CFbend fast
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Feb 14, 2025
1 parent 2eb9384 commit 55bc5e9
Showing 1 changed file with 96 additions and 49 deletions.
145 changes: 96 additions & 49 deletions src/elements/CFbend.H
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,56 @@ namespace impactx::elements
/** Push all particles */
using BeamOptic::operator();

/** Compute and cache the constants for the push.
*
* In particular, used to pre-compute and cache variables that are
* independent of the individually tracked particle.
*
* @param refpart reference particle
*/
void compute_constants (RefPart const & refpart)
{
using namespace amrex::literals; // for _rt and _prt

Alignment::compute_constants(refpart);

// length of the current slice
m_slice_ds = m_ds / nslice();

// find beta*gamma^2, beta
amrex::ParticleReal const betgam2 = std::pow(refpart.pt, 2) - 1.0_prt;
amrex::ParticleReal const bet = std::sqrt(betgam2 / (1.0_prt + betgam2));
m_ibetgam2 = 1.0_prt / betgam2;
amrex::ParticleReal const b2rc2 = std::pow(bet, 2) * std::pow(m_rc, 2);
m_igobr = 1.0_prt / ( m_gx * m_omega_x * b2rc2 );

// update horizontal and longitudinal phase space variables
m_gx = m_k + std::pow(m_rc,-2);
m_omega_x = std::sqrt(std::abs(m_gx));

// update vertical phase space variables
m_gy = -m_k;
m_omega_y = std::sqrt(std::abs(m_gy));

// trigonometry
auto const [sinx, cosx] = amrex::Math::sincos(m_omega_x * m_slice_ds);
m_sinx = sinx;
m_cosx = cosx;
m_sinhx = std::sinh(m_omega_x * m_slice_ds);
m_coshx = std::cosh(m_omega_x * m_slice_ds);
auto const [siny, cosy] = amrex::Math::sincos(m_omega_y * m_slice_ds);
m_siny = siny;
m_cosy = cosy;
m_sinhy = std::sinh(m_omega_y * m_slice_ds);
m_coshy = std::cosh(m_omega_y * m_slice_ds);

m_rgbrc = 1.0_prt / ( m_gx * bet * m_rc );
m_robrc = m_omega_x * bet * m_rc;
}

/** This is a cfbend functor, so that a variable of this type can be used like a cfbend function.
*
* The @see compute_constants method must be called before pushing particles through this operator.
*
* @param x particle position in x
* @param y particle position in y
Expand All @@ -90,7 +139,7 @@ namespace impactx::elements
* @param py particle momentum in y
* @param pt particle momentum in t
* @param idcpu particle global index
* @param refpart reference particle
* @param refpart reference particle (unused)
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void operator() (
Expand All @@ -101,7 +150,7 @@ namespace impactx::elements
amrex::ParticleReal & AMREX_RESTRICT py,
amrex::ParticleReal & AMREX_RESTRICT pt,
uint64_t & AMREX_RESTRICT idcpu,
RefPart const & refpart
[[maybe_unused]] RefPart const & AMREX_RESTRICT refpart
) const
{
using namespace amrex::literals; // for _rt and _prt
Expand All @@ -119,67 +168,44 @@ namespace impactx::elements
amrex::ParticleReal pyout = py;
amrex::ParticleReal ptout = pt;

// length of the current slice
amrex::ParticleReal const slice_ds = m_ds / nslice();

// access reference particle values to find beta*gamma^2
amrex::ParticleReal const pt_ref = refpart.pt;
amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt;
amrex::ParticleReal const bet = std::sqrt(betgam2/(1.0_prt + betgam2));

// update horizontal and longitudinal phase space variables
amrex::ParticleReal const gx = m_k + std::pow(m_rc,-2);
amrex::ParticleReal const omegax = std::sqrt(std::abs(gx));

if(gx > 0.0) {
// calculate expensive terms once
auto const [sinx, cosx] = amrex::Math::sincos(omegax * slice_ds);
amrex::ParticleReal const r56 = slice_ds/betgam2
+ (sinx - omegax*slice_ds)/(gx*omegax * std::pow(bet,2) * std::pow(m_rc,2));
if (m_gx > 0.0)
{
amrex::ParticleReal const r56 = m_slice_ds * m_ibetgam2
+ (m_sinx - m_omega_x * m_slice_ds) * m_igobr;

// advance position and momentum (focusing)
x = cosx*xout + sinx/omegax*px - (1.0_prt - cosx)/(gx*bet*m_rc)*pt;
pxout = -omegax*sinx*xout + cosx*px - sinx/(omegax*bet*m_rc)*pt;
x = m_cosx * xout + m_sinx / m_omega_x * px - (1.0_prt - m_cosx) * m_rgbrc * pt;
pxout = -m_omega_x * m_sinx * xout + m_cosx * px - m_sinx * m_robrc * pt;

y = sinx/(omegax*bet*m_rc)*xout + (1.0_prt - cosx)/(gx*bet*m_rc)*px
+ tout + r56*pt;
y = m_sinx * m_robrc * xout + (1.0_prt - m_cosx) * m_rgbrc * px
+ tout + r56 * pt;
ptout = pt;
} else {
// calculate expensive terms once
amrex::ParticleReal const sinhx = std::sinh(omegax * slice_ds);
amrex::ParticleReal const coshx = std::cosh(omegax * slice_ds);
amrex::ParticleReal const r56 = slice_ds/betgam2
+ (sinhx - omegax*slice_ds)/(gx*omegax * std::pow(bet,2) * std::pow(m_rc,2));
} else
{
amrex::ParticleReal const r56 = m_slice_ds * m_ibetgam2
+ (m_sinhx - m_omega_x * m_slice_ds) * m_igobr;

// advance position and momentum (defocusing)
x = coshx*xout + sinhx/omegax*px - (1.0_prt - coshx)/(gx*bet*m_rc)*pt;
pxout = omegax*sinhx*xout + coshx*px - sinhx/(omegax*bet*m_rc)*pt;
x = m_coshx * xout + m_sinhx / m_omega_x * px - (1.0_prt - m_coshx) * m_rgbrc * pt;
pxout = m_omega_x * m_sinhx * xout + m_coshx * px - m_sinhx * m_robrc * pt;

t = sinhx/(omegax*bet*m_rc)*xout + (1.0_prt - coshx)/(gx*bet*m_rc)*px
+ tout + r56*pt;
t = m_sinhx * m_robrc * xout + (1.0_prt - m_coshx) * m_rgbrc * px
+ tout + r56 * pt;
ptout = pt;
}

// update vertical phase space variables
amrex::ParticleReal const gy = -m_k;
amrex::ParticleReal const omegay = std::sqrt(std::abs(gy));

if(gy > 0.0) {
// calculate expensive terms once
auto const [siny, cosy] = amrex::Math::sincos(omegay * slice_ds);

if (m_gy > 0.0)
{
// advance position and momentum (focusing)
y = cosy*yout + siny/omegay*py;
pyout = -omegay*siny*yout + cosy*py;

} else {
// calculate expensive terms once
amrex::ParticleReal const sinhy = std::sinh(omegay * slice_ds);
amrex::ParticleReal const coshy = std::cosh(omegay * slice_ds);

y = m_cosy * yout + m_siny / m_omega_y * py;
pyout = -m_omega_y * m_siny * yout + m_cosy * py;
} else
{
// advance position and momentum (defocusing)
y = coshy*yout + sinhy/omegay*py;
pyout = omegay*sinhy*yout + coshy*py;
y = m_coshy * yout + m_sinhy / m_omega_y * py;
pyout = m_omega_y * m_sinhy * yout + m_coshy * py;
}

// assign updated momenta
Expand Down Expand Up @@ -257,6 +283,27 @@ namespace impactx::elements

amrex::ParticleReal m_rc; //! bend radius in m
amrex::ParticleReal m_k; //! quadrupole strength in m^(-2)

private:
// constants that are independent of the individually tracked particle,
// see: compute_constants() to refresh
amrex::ParticleReal m_slice_ds; //! m_ds / nslice();
amrex::ParticleReal m_ibetgam2; //! 1 / (beta*gamma^2)
amrex::ParticleReal m_gx;
amrex::ParticleReal m_gy;
amrex::ParticleReal m_omega_x;
amrex::ParticleReal m_omega_y;
amrex::ParticleReal m_igobr;
amrex::ParticleReal m_sinx; //! sin(omegax*x)
amrex::ParticleReal m_cosx; //! cos(omegax*x)
amrex::ParticleReal m_sinhx; //! sinh(omegax*x)
amrex::ParticleReal m_coshx; //! cosh(omegax*x)
amrex::ParticleReal m_siny; //! sin(omegay*y)
amrex::ParticleReal m_cosy; //! cos(omegay*y)
amrex::ParticleReal m_sinhy; //! sinh(omegay*y)
amrex::ParticleReal m_coshy; //! cosh(omegay*y)
amrex::ParticleReal m_rgbrc;
amrex::ParticleReal m_robrc;
};

} // namespace impactx
Expand Down

0 comments on commit 55bc5e9

Please sign in to comment.