Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add frequencies as optional CRAB parameter #18

Merged
merged 13 commits into from
May 6, 2024
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@ def qutip_qtrl_version():
# -- Options for intersphinx ---------------------------------------

intersphinx_mapping = {
"qutip": ("https://qutip.org/docs/latest/", None),
"qutip": ("https://qutip.readthedocs.io/en/latest/", None),
}
38 changes: 29 additions & 9 deletions src/qutip_qtrl/pulsegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,12 +1017,12 @@ def reset(self):
self.guess_pulse_func = None
self.apply_params()

def init_pulse(self, num_coeffs=None):
def init_pulse(self, num_coeffs=None, init_coeffs=None):
"""
Set the initial freq and coefficient values
"""
PulseGen.init_pulse(self)
self.init_coeffs(num_coeffs=num_coeffs)
self.init_coeffs(num_coeffs=num_coeffs, init_coeffs=init_coeffs)

if self.guess_pulse is not None:
self.init_guess_pulse()
Expand All @@ -1042,16 +1042,18 @@ def init_pulse(self, num_coeffs=None):
# self.guess_pulse = self.guess_pulsegen.gen_pulse()
# return self.guess_pulse

def init_coeffs(self, num_coeffs=None):
def init_coeffs(self, num_coeffs=None, init_coeffs=None):
"""
Generate the initial ceofficent values.
Generate or set the initial ceofficent values.

Parameters
----------
num_coeffs : integer
Number of coefficients used for each basis function
If given this overides the default and sets the attribute
of the same name.
init_coeffs : float array[num_coeffs * num_basis_funcs]
Typically this will be the initial basis coefficients.
"""
if num_coeffs:
self.num_coeffs = num_coeffs
Expand Down Expand Up @@ -1087,8 +1089,9 @@ def init_coeffs(self, num_coeffs=None):
self.num_coeffs, self.NUM_COEFFS_WARN_LVL
)
)

if self.randomize_coeffs:
if init_coeffs is not None:
self.set_coeffs(init_coeffs)
elif self.randomize_coeffs:
r = np.random.random([self.num_coeffs, self.num_basis_funcs])
self.coeffs = (2 * r - 1.0) * self.scaling
else:
Expand Down Expand Up @@ -1236,8 +1239,20 @@ class PulseGenCrabFourier(PulseGenCrab):
Frequencies for the basis functions
randomize_freqs : bool
If True (default) the some random offset is applied to the frequencies
fix_freqs : bool
If True (default) then the frequencies of the basis functions are fixed
and the number of basis functions is set to 2 (sin and cos).
If False then the frequencies are also optimised, adding an additional
parameter for each pair of basis functions.
"""

def __init__(self, dyn=None, num_coeffs=None, params=None, fix_freqs=True):
PulseGenCrab.__init__(self, dyn, num_coeffs, params)
self.fix_freqs = fix_freqs
if not self.fix_freqs:
# additional parameter for the frequency
self.num_basis_funcs += 1

def reset(self):
"""
reset attributes to default values
Expand All @@ -1246,11 +1261,13 @@ def reset(self):
self.freqs = None
self.randomize_freqs = True

def init_pulse(self, num_coeffs=None):
def init_pulse(self, num_coeffs=None, init_coeffs=None):
"""
Set the initial freq and coefficient values
"""
PulseGenCrab.init_pulse(self)
PulseGenCrab.init_pulse(
self, num_coeffs=num_coeffs, init_coeffs=init_coeffs
)

self.init_freqs()

Expand Down Expand Up @@ -1289,7 +1306,10 @@ def gen_pulse(self, coeffs=None):
pulse = np.zeros(self.num_tslots)

for i in range(self.num_coeffs):
phase = self.freqs[i] * self.time
if self.fix_freqs: # dont optimise frequencies
phase = self.freqs[i] * self.time
else: # optimise frequencies as part of the parameters
phase = self.coeffs[i, 2] * self.time
pulse += self.coeffs[i, 0] * np.sin(phase) + self.coeffs[
i, 1
] * np.cos(phase)
Expand Down
6 changes: 5 additions & 1 deletion src/qutip_qtrl/pulseoptim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,9 +2020,11 @@ def create_pulse_optimizer(
# Create a pulse generator for each ctrl
crab_pulse_params = None
num_coeffs = None
fix_freqs = True
init_coeff_scaling = None
if isinstance(alg_params, dict):
num_coeffs = alg_params.get("num_coeffs")
fix_freqs = alg_params.get("fix_frequency")
init_coeff_scaling = alg_params.get("init_coeff_scaling")
if "crab_pulse_params" in alg_params:
crab_pulse_params = alg_params.get("crab_pulse_params")
Expand All @@ -2042,7 +2044,9 @@ def create_pulse_optimizer(
optim.pulse_generator = []
for j in range(n_ctrls):
crab_pgen = pulsegen.PulseGenCrabFourier(
dyn=dyn, num_coeffs=num_coeffs
dyn=dyn,
num_coeffs=num_coeffs,
fix_freqs=fix_freqs,
)
if init_coeff_scaling is not None:
crab_pgen.scaling = init_coeff_scaling
Expand Down
34 changes: 34 additions & 0 deletions tests/test_control_pulseoptim.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,40 @@ def test_crab(self, propagation):
assert abs(result.fid_err) < tol
assert abs(result.final_amps[0, 0]) < tol, "Lead-in amplitude nonzero."

def test_init_pulsegencrab(self, system, propagation):
"""
Test the initialization of CRAB Fourier parameters
"""
# Setup the optimizer
system = _merge_kwargs(system, propagation)
optimizer = cpo.create_pulse_optimizer(
system.system,
system.controls,
system.initial,
system.target,
alg="CRAB",
**system.kwargs,
)
dynamics = optimizer.dynamics

init_amps = np.zeros([dynamics.num_tslots, dynamics.num_ctrls])
# Generate initial pulses for each control through generator
for i, pgen in enumerate(optimizer.pulse_generator):
num_coeffs = 8
vals = np.ones(num_coeffs * pgen.num_basis_funcs)
pgen.init_pulse(num_coeffs, init_coeffs=vals)
init_amps[:, i] = pgen.gen_pulse()

# Initialize the dynamics with the initial amplitudes
dynamics.initialize_controls(init_amps)

# Run the optimization
result = optimizer.run_optimization()
assert isinstance(result.fid_err, float)

# Check proper initialization of CRAB amplitudes
assert np.isclose(result.initial_amps, init_amps).all()


# The full object-orientated interface to the optimiser is rather complex. To
# attempt to simplify the test of the configuration loading, we break it down
Expand Down
Loading