Skip to content

Commit

Permalink
[Hack] Make Quad Fast
Browse files Browse the repository at this point in the history
Make the `Quad` fast by calculating constants that depend
on the element config and reference particle properties
before pushing all particles.
  • Loading branch information
ax3l committed Feb 12, 2025
1 parent 220531e commit 285fa94
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 30 deletions.
81 changes: 52 additions & 29 deletions src/elements/Quad.H
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ namespace impactx::elements
/** Push all particles */
using BeamOptic::operator();

void calc_constants (RefPart const & refpart)
{
using namespace amrex::literals; // for _rt and _prt

// length of the current slice
m_slice_ds = m_ds / nslice();

amrex::ParticleReal const pt_ref = refpart.pt;
// find beta*gamma^2
m_betgam2 = std::pow(pt_ref, 2) - 1.0_prt;

// assign intermediate parameter
m_step = m_slice_ds / std::sqrt(std::pow(pt_ref, 2) - 1.0_prt);

// compute phase advance per unit length in s (in rad/m)
m_omega = std::sqrt(std::abs(m_k));
m_sin_omega_ds = std::sin(m_omega*m_slice_ds);
m_cos_omega_ds = std::cos(m_omega*m_slice_ds);
m_sinh_omega_ds = std::sinh(m_omega*m_slice_ds);
m_cosh_omega_ds = std::cosh(m_omega*m_slice_ds);
}

/** This is a quad functor, so that a variable of this type can be used like a quad function.
*
* @param x particle position in x
Expand All @@ -96,24 +118,14 @@ namespace impactx::elements
amrex::ParticleReal & AMREX_RESTRICT py,
amrex::ParticleReal & AMREX_RESTRICT pt,
uint64_t & AMREX_RESTRICT idcpu,
RefPart const & refpart
[[maybe_unused]] RefPart const & refpart
) const
{
using namespace amrex::literals; // for _rt and _prt

// shift due to alignment errors of the element
shift_in(x, y, px, py);

// length of the current slice
amrex::ParticleReal const slice_ds = m_ds / nslice();

// access reference particle values to find beta*gamma^2
amrex::ParticleReal const pt_ref = refpart.pt;
amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt;

// compute phase advance per unit length in s (in rad/m)
amrex::ParticleReal const omega = std::sqrt(std::abs(m_k));

// intialize output values
amrex::ParticleReal xout = x;
amrex::ParticleReal yout = y;
Expand All @@ -124,31 +136,31 @@ namespace impactx::elements

if (m_k > 0.0) {
// advance position and momentum (focusing quad)
xout = std::cos(omega*slice_ds)*x + std::sin(omega*slice_ds)/omega*px;
pxout = -omega*std::sin(omega*slice_ds)*x + std::cos(omega*slice_ds)*px;
xout = m_cos_omega_ds*x + m_sin_omega_ds/m_omega*px;
pxout = -m_omega*m_sin_omega_ds*x + m_cos_omega_ds*px;

yout = std::cosh(omega*slice_ds)*y + std::sinh(omega*slice_ds)/omega*py;
pyout = omega*std::sinh(omega*slice_ds)*y + std::cosh(omega*slice_ds)*py;
yout = m_cosh_omega_ds*y + m_sinh_omega_ds/m_omega*py;
pyout = m_omega*m_sinh_omega_ds*y + m_cosh_omega_ds*py;

tout = t + (slice_ds/betgam2)*pt;
tout = t + (m_slice_ds/m_betgam2)*pt;
// ptout = pt;
} else if (m_k < 0.0) {
// advance position and momentum (defocusing quad)
xout = std::cosh(omega*slice_ds)*x + std::sinh(omega*slice_ds)/omega*px;
pxout = omega*std::sinh(omega*slice_ds)*x + std::cosh(omega*slice_ds)*px;
xout = m_cosh_omega_ds*x + m_sinh_omega_ds/m_omega*px;
pxout = m_omega*m_sinh_omega_ds*x + m_cosh_omega_ds*px;

yout = std::cos(omega*slice_ds)*y + std::sin(omega*slice_ds)/omega*py;
pyout = -omega*std::sin(omega*slice_ds)*y + std::cos(omega*slice_ds)*py;
yout = m_cos_omega_ds*y + m_sin_omega_ds/m_omega*py;
pyout = -m_omega*m_sin_omega_ds*y + m_cos_omega_ds*py;

tout = t + (slice_ds/betgam2)*pt;
tout = t + (m_slice_ds/m_betgam2)*pt;
// ptout = pt;
} else {
// advance position and momentum (zero strength = drift)
xout = x + slice_ds * px;
xout = x + m_slice_ds * px;
// pxout = px;
yout = y + slice_ds * py;
yout = y + m_slice_ds * py;
// pyout = py;
tout = t + (slice_ds/betgam2) * pt;
tout = t + (m_slice_ds/m_betgam2) * pt;
// ptout = pt;

}
Expand All @@ -173,7 +185,7 @@ namespace impactx::elements
* @param[in,out] refpart reference particle
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void operator() (RefPart & AMREX_RESTRICT refpart) const {
void operator() (RefPart & AMREX_RESTRICT refpart) const { // TODO: update as well, but needs more careful placement of calc_constants

using namespace amrex::literals; // for _rt and _prt

Expand Down Expand Up @@ -210,7 +222,7 @@ namespace impactx::elements
*/
AMREX_GPU_HOST AMREX_FORCE_INLINE
Map6x6
transport_map (RefPart const & AMREX_RESTRICT refpart) const
transport_map (RefPart const & AMREX_RESTRICT refpart) const // TODO: update as well, but needs more careful placement of calc_constants
{
using namespace amrex::literals; // for _rt and _prt

Expand Down Expand Up @@ -248,9 +260,9 @@ namespace impactx::elements
R(4,4) = std::cos(omega*slice_ds);
R(5,6) = slice_ds/betgam2;
} else {
R(1,2) = slice_ds;
R(3,4) = slice_ds;
R(5,6) = slice_ds / betgam2;
R(1,2) = m_slice_ds;
R(3,4) = m_slice_ds;
R(5,6) = m_slice_ds / betgam2;
}

return R;
Expand All @@ -260,6 +272,17 @@ namespace impactx::elements
using LinearTransport::operator();

amrex::ParticleReal m_k; //! quadrupole strength in 1/m

private:
// constants
amrex::ParticleReal m_step; //! ...
amrex::ParticleReal m_slice_ds; //! m_ds / nslice();
amrex::ParticleReal m_betgam2; //! beta*gamma^2
amrex::ParticleReal m_omega; //! std::sqrt(std::abs(m_k)) compute phase advance per unit length in s (in rad/m)
amrex::ParticleReal m_sin_omega_ds; //! std::sin(omega*slice_ds)
amrex::ParticleReal m_cos_omega_ds; //! std::cos(omega*slice_ds)
amrex::ParticleReal m_sinh_omega_ds; //! std::sinh(omega*slice_ds)
amrex::ParticleReal m_cosh_omega_ds; //! std::cosh(omega*slice_ds)
};

} // namespace impactx
Expand Down
15 changes: 14 additions & 1 deletion src/elements/mixin/beamoptic.H
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@
#include <type_traits>


namespace impactx::elements {
struct Quad;
}

namespace impactx::elements::mixin
{
namespace detail
{

/** Push a single particle through an element
*
* Note: we usually would just write a C++ lambda below in ParallelFor. But, due to restrictions
Expand Down Expand Up @@ -132,8 +137,16 @@ namespace detail

uint64_t* const AMREX_RESTRICT part_idcpu = pti.GetStructOfArrays().GetIdCPUData().dataPtr();

detail::PushSingleParticle<T_Element> const pushSingleParticle(
if constexpr (std::is_same_v<std::decay_t<T_Element>, elements::Quad>)
{
element.calc_constants(ref_part);
}

detail::PushSingleParticle<T_Element> pushSingleParticle(
element, part_x, part_y, part_t, part_px, part_py, part_pt, part_idcpu, ref_part);



// loop over beam particles in the box
amrex::ParallelFor(np, pushSingleParticle);
}
Expand Down

0 comments on commit 285fa94

Please sign in to comment.