Commit 8e54384

[Tensornet] Support trajectory simulation for unitary mixture noise channels (NVIDIA#2520)
* Initial work on adding support for cutensornetStateApplyUnitaryChannel
  Signed-off-by: Thien Nguyen <[email protected]>
* Enable some noise test cases on tensornet
  Signed-off-by: Thien Nguyen <[email protected]>
* Fix a copy-and-paste error
  Signed-off-by: Thien Nguyen <[email protected]>
* support unitary mixture channel detection and enable more tests
  Signed-off-by: Thien Nguyen <[email protected]>
* MPS trajectory: we need to compute the MPS factorization for each trajectory
  Signed-off-by: Thien Nguyen <[email protected]>
* Split cutensornetStateFinalizeMPS and (cutensornetStatePrepare + cutensornetStateCompute)
  The first one is only needed once for trajectory simulation.
  Signed-off-by: Thien Nguyen <[email protected]>
* Add trajectories to observe
  Signed-off-by: Thien Nguyen <[email protected]>
* Handle unitary channels in all code paths
  Signed-off-by: Thien Nguyen <[email protected]>
* Code format
  Signed-off-by: Thien Nguyen <[email protected]>
* Reduce test time
  Signed-off-by: Thien Nguyen <[email protected]>
* Update cutensornet version requirement
  Signed-off-by: Thien Nguyen <[email protected]>
* Add sampler cache for MPS trajectory
  Signed-off-by: Thien Nguyen <[email protected]>
* Add cache workspace mem
  Signed-off-by: Thien Nguyen <[email protected]>
* Add trajectory support to non-path-reuse path merging from main
  Signed-off-by: Thien Nguyen <[email protected]>
* Update some of the cutensornet DEPRECATED enums
  Signed-off-by: Thien Nguyen <[email protected]>
* Make number of hyper sample configurable
  Signed-off-by: Thien Nguyen <[email protected]>
* Docs update
  Signed-off-by: Thien Nguyen <[email protected]>
* CR: Correct Pauli Y matrix
  Signed-off-by: Thien Nguyen <[email protected]>
* CR: refactor SimulatorTensorNetBase::applyNoiseChannel
  Signed-off-by: Thien Nguyen <[email protected]>
* CR: code refactor in MPS implementation
  Signed-off-by: Thien Nguyen <[email protected]>
* Add an exact output state vec check for tensornet to check matrix data
  Signed-off-by: Thien Nguyen <[email protected]>
* CR: Add a code comment for a helper function
  Signed-off-by: Thien Nguyen <[email protected]>

---------

Signed-off-by: Thien Nguyen <[email protected]>
1 parent 8b37cb3 commit 8e54384
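
For readers unfamiliar with the term, trajectory simulation of a unitary mixture channel can be illustrated with a toy example. The sketch below is not taken from this commit: the function name `estimateZAfterBitFlip` and its parameters are hypothetical, and it models a single qubit rather than a tensor network. Each trajectory samples one unitary from the mixture (here I or X for a bit-flip channel), evolves a pure state, and the observable is averaged over trajectories.

```cpp
#include <cassert>
#include <cstddef>
#include <random>

// Hypothetical toy model (not from the commit): estimate <Z> for a qubit
// prepared in |0> after a bit-flip channel with flip probability p.
// Each trajectory samples ONE unitary from the mixture (I with probability
// 1 - p, X with probability p) and evolves a pure state, so no density
// matrix is ever formed.
double estimateZAfterBitFlip(double p, std::size_t numTrajectories,
                             unsigned seed) {
  assert(p >= 0.0 && p <= 1.0 && numTrajectories > 0);
  std::mt19937 rng(seed);
  std::bernoulli_distribution flip(p);
  double sum = 0.0;
  for (std::size_t t = 0; t < numTrajectories; ++t) {
    // I|0> = |0> gives <Z> = +1; X|0> = |1> gives <Z> = -1.
    sum += flip(rng) ? -1.0 : 1.0;
  }
  // The trajectory average converges to the exact value 1 - 2p.
  return sum / static_cast<double>(numTrajectories);
}
```

Calling `estimateZAfterBitFlip(0.25, 200000, 42)` should return a value close to the exact result 1 - 2p = 0.5, with a statistical error that shrinks as the number of trajectories grows.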

13 files changed: +994 -257 lines changed

docs/sphinx/using/backends/sims/tnsims.rst

Lines changed: 3 additions & 0 deletions
@@ -81,6 +81,9 @@ Specific aspects of the simulation can be configured by setting the following of
 * **`CUDA_VISIBLE_DEVICES=X`**: Makes the process only see GPU X on multi-GPU nodes. Each MPI process must only see its own dedicated GPU. For example, if you run 8 MPI processes on a DGX system with 8 GPUs, each MPI process should be assigned its own dedicated GPU via `CUDA_VISIBLE_DEVICES` when invoking `mpiexec` (or `mpirun`) commands.
 * **`OMP_PLACES=cores`**: Set this environment variable to improve CPU parallelization.
 * **`OMP_NUM_THREADS=X`**: To enable CPU parallelization, set X to `NUMBER_OF_CORES_PER_NODE/NUMBER_OF_GPUS_PER_NODE`.
+* **`CUDAQ_TENSORNET_CONTROLLED_RANK=X`**: Specify the number of controlled qubits whereby the full tensor body of the controlled gate is expanded. If the number of controlled qubits is greater than this value, the gate is applied as a controlled tensor operator to the tensor network state. Default value is 1.
+* **`CUDAQ_TENSORNET_OBSERVE_CONTRACT_PATH_REUSE=X`**: Set this environment variable to `TRUE` (`ON`) or `FALSE` (`OFF`) to enable or disable contraction path reuse when computing expectation values. Default is `OFF`.
+* **`CUDAQ_TENSORNET_NUM_HYPER_SAMPLES=X`**: Specify the number of hyper samples used in the tensor network contraction path finder. Default value is 8 if not specified.
 
 .. note::
 

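The integer-valued options documented above (for example `CUDAQ_TENSORNET_NUM_HYPER_SAMPLES`, default 8) follow a common "read from the environment, else fall back to a default" pattern. The diff shown here does not include the code that parses these variables, so the following is only a sketch of that pattern; `parsePositiveIntOr` and `readPositiveIntEnv` are hypothetical names, not CUDA-Q functions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <exception>
#include <string>

// Hypothetical helper (not from the CUDA-Q sources): parse a positive
// integer, returning defaultValue when the text is missing or invalid.
int parsePositiveIntOr(const char *raw, int defaultValue) {
  if (raw == nullptr)
    return defaultValue;
  try {
    const std::string text(raw);
    std::size_t consumed = 0;
    const int value = std::stoi(text, &consumed);
    // Accept only fully consumed, strictly positive values.
    if (consumed == text.size() && value > 0)
      return value;
  } catch (const std::exception &) {
    // Not a number, or out of range: fall through to the default.
  }
  return defaultValue;
}

// Hypothetical helper: look the option up in the process environment.
int readPositiveIntEnv(const char *name, int defaultValue) {
  assert(name != nullptr);
  return parsePositiveIntOr(std::getenv(name), defaultValue);
}
```

With these helpers, `readPositiveIntEnv("CUDAQ_TENSORNET_NUM_HYPER_SAMPLES", 8)` would yield 8 when the variable is unset, matching the documented default.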
runtime/common/NoiseModel.cpp

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ kraus_channel &kraus_channel::operator=(const kraus_channel &other) {
   return *this;
 }
 
-std::vector<kraus_op> kraus_channel::get_ops() { return ops; }
+std::vector<kraus_op> kraus_channel::get_ops() const { return ops; }
 void kraus_channel::push_back(kraus_op op) { ops.push_back(op); }
 
 void noise_model::add_channel(const std::string &quantumOp,

runtime/common/NoiseModel.h

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ class kraus_channel {
   kraus_channel &operator=(const kraus_channel &other);
 
   /// @brief Return all kraus_ops in this channel
-  std::vector<kraus_op> get_ops();
+  std::vector<kraus_op> get_ops() const;
 
   /// @brief Add a kraus_op to this channel.
   void push_back(kraus_op op);

runtime/nvqir/cutensornet/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@ set(CUTENSORNET_PATCH ${CMAKE_MATCH_1})
 
 set(CUTENSORNET_VERSION ${CUTENSORNET_MAJOR}.${CUTENSORNET_MINOR}.${CUTENSORNET_PATCH})
 message(STATUS "Found cutensornet version: ${CUTENSORNET_VERSION}")
-# We need cutensornet v2.5.0+
-if (${CUTENSORNET_VERSION} VERSION_GREATER_EQUAL "2.5")
+# We need cutensornet v2.6.0+ (cutensornetStateApplyUnitaryChannel)
+if (${CUTENSORNET_VERSION} VERSION_GREATER_EQUAL "2.6")
   set (BASE_TENSOR_BACKEND_SRS
     simulator_cutensornet.cpp
     tensornet_spin_op.cpp

runtime/nvqir/cutensornet/simulator_cutensornet.cpp

Lines changed: 213 additions & 7 deletions
@@ -11,6 +11,17 @@
 #include "cutensornet.h"
 #include "tensornet_spin_op.h"
 
+namespace {
+const std::vector<std::complex<double>> matPauliI = {
+    {1.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {1.0, 0.0}};
+const std::vector<std::complex<double>> matPauliX{
+    {0.0, 0.0}, {1.0, 0.0}, {1.0, 0.0}, {0.0, 0.0}};
+const std::vector<std::complex<double>> matPauliY{
+    {0.0, 0.0}, {0.0, -1.0}, {0.0, 1.0}, {0.0, 0.0}};
+const std::vector<std::complex<double>> matPauliZ{
+    {1.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {-1.0, 0.0}};
+} // namespace
+
 namespace nvqir {
 
 SimulatorTensorNetBase::SimulatorTensorNetBase()
@@ -137,6 +148,201 @@ void SimulatorTensorNetBase::applyGate(const GateApplicationTask &task) {
   }
 }
 
+// Helper to check whether a matrix is a scaled unitary matrix, i.e., `k * U`
+// where U is a unitary matrix. If so, it also returns the `k` factor.
+// Otherwise, return a nullopt.
+template <typename T>
+std::optional<double> isScaledUnitary(const std::vector<std::complex<T>> &mat,
+                                      double eps) {
+  typedef Eigen::Matrix<std::complex<T>, Eigen::Dynamic, Eigen::Dynamic,
+                        Eigen::RowMajor>
+      RowMajorMatTy;
+  const int dim = std::log2(mat.size());
+  Eigen::Map<const RowMajorMatTy> kMat(mat.data(), dim, dim);
+  if (kMat.isZero())
+    return std::nullopt;
+  // Check that (K_dag * K) is a scaled identity matrix
+  // i.e., the K matrix is a scaled unitary.
+  auto kdK = kMat.adjoint() * kMat;
+  if (!kdK.isDiagonal())
+    return std::nullopt;
+  // First element
+  std::complex<T> val = kdK(0, 0);
+  if (std::abs(val) > eps && std::abs(val.imag()) < eps) {
+    auto scaledKdK = (std::complex<T>{1.0} / val) * kdK;
+    if (scaledKdK.isIdentity())
+      return val.real();
+  }
+  return std::nullopt;
+}
+
+std::optional<std::pair<
+    std::vector<double>,
+    std::vector<std::vector<std::complex<
+        double>>>>> static computeUnitaryMixture(const std::
+                                                     vector<std::vector<
+                                                         std::complex<double>>>
+                                                         &krausOps,
+                                                 double tol = 1e-6) {
+  std::vector<double> probs;
+  std::vector<std::vector<std::complex<double>>> mats;
+  const auto scaleMat = [](const std::vector<std::complex<double>> &mat,
+                           double scaleFactor) {
+    std::vector<std::complex<double>> scaledMat = mat;
+    for (auto &x : scaledMat)
+      x /= scaleFactor;
+    return scaledMat;
+  };
+  for (const auto &op : krausOps) {
+    const auto scaledFactor = isScaledUnitary(op, tol);
+    if (!scaledFactor.has_value())
+      return std::nullopt;
+    probs.emplace_back(scaledFactor.value());
+    mats.emplace_back(scaleMat(op, scaledFactor.value()));
+  }
+
+  if (std::abs(1.0 - std::reduce(probs.begin(), probs.end())) > tol)
+    return std::nullopt;
+
+  return std::make_pair(probs, mats);
+}
+
+// Helper to look up a device memory pointer from a cache.
+// If not found, allocate a new device memory buffer and put it to the cache.
+static void *
+getOrCacheMat(const std::string &key,
+              const std::vector<std::complex<double>> &mat,
+              std::unordered_map<std::string, void *> &gateDeviceMemCache) {
+  const auto iter = gateDeviceMemCache.find(key);
+
+  if (iter == gateDeviceMemCache.end()) {
+    void *dMem = allocateGateMatrix(mat);
+    gateDeviceMemCache[key] = dMem;
+    return dMem;
+  }
+  return iter->second;
+};
+
+void SimulatorTensorNetBase::applyKrausChannel(
+    const std::vector<int32_t> &qubits,
+    const cudaq::kraus_channel &krausChannel) {
+  LOG_API_TIME();
+  switch (krausChannel.noise_type) {
+  case cudaq::noise_model_type::depolarization_channel: {
+    if (krausChannel.parameters.size() != 1)
+      throw std::runtime_error(
+          fmt::format("Invalid parameters for a depolarization channel. "
+                      "Expecting 1 parameter, got {}.",
+                      krausChannel.parameters.size()));
+    const std::vector<void *> channelMats{
+        getOrCacheMat("PauliI", matPauliI, m_gateDeviceMemCache),
+        getOrCacheMat("PauliX", matPauliX, m_gateDeviceMemCache),
+        getOrCacheMat("PauliY", matPauliY, m_gateDeviceMemCache),
+        getOrCacheMat("PauliZ", matPauliZ, m_gateDeviceMemCache)};
+    const double p = krausChannel.parameters[0];
+    const std::vector<double> probabilities = {1 - p, p / 3., p / 3., p / 3.};
+    m_state->applyUnitaryChannel(qubits, channelMats, probabilities);
+    break;
+  }
+  case cudaq::noise_model_type::bit_flip_channel: {
+    if (krausChannel.parameters.size() != 1)
+      throw std::runtime_error(
+          fmt::format("Invalid parameters for a bit-flip channel. "
+                      "Expecting 1 parameter, got {}.",
+                      krausChannel.parameters.size()));
+
+    const std::vector<void *> channelMats{
+        getOrCacheMat("PauliI", matPauliI, m_gateDeviceMemCache),
+        getOrCacheMat("PauliX", matPauliX, m_gateDeviceMemCache)};
+    const double p = krausChannel.parameters[0];
+    const std::vector<double> probabilities = {1 - p, p};
+    m_state->applyUnitaryChannel(qubits, channelMats, probabilities);
+    break;
+  }
+  case cudaq::noise_model_type::phase_flip_channel: {
+    if (krausChannel.parameters.size() != 1)
+      throw std::runtime_error(
+          fmt::format("Invalid parameters for a phase-flip channel. "
+                      "Expecting 1 parameter, got {}.",
+                      krausChannel.parameters.size()));
+
+    const std::vector<void *> channelMats{
+        getOrCacheMat("PauliI", matPauliI, m_gateDeviceMemCache),
+        getOrCacheMat("PauliZ", matPauliZ, m_gateDeviceMemCache)};
+    const double p = krausChannel.parameters[0];
+    const std::vector<double> probabilities = {1 - p, p};
+    m_state->applyUnitaryChannel(qubits, channelMats, probabilities);
+    break;
+  }
+  case cudaq::noise_model_type::amplitude_damping_channel: {
+    if (krausChannel.parameters.size() != 1)
+      throw std::runtime_error(
+          fmt::format("Invalid parameters for a amplitude damping channel. "
+                      "Expecting 1 parameter, got {}.",
+                      krausChannel.parameters.size()));
+    if (krausChannel.parameters[0] != 0.0)
+      throw std::runtime_error("Non-unitary noise channels are not supported.");
+    break;
+  }
+  case cudaq::noise_model_type::unknown: {
+    std::vector<std::vector<std::complex<double>>> mats;
+    for (const auto &op : krausChannel.get_ops())
+      mats.emplace_back(op.data);
+    auto asUnitaryMixture = computeUnitaryMixture(mats);
+    if (asUnitaryMixture.has_value()) {
+      auto &[probabilities, unitaries] = asUnitaryMixture.value();
+      std::vector<void *> channelMats;
+      for (const auto &mat : unitaries)
+        channelMats.emplace_back(getOrCacheMat(
+            "ScaledUnitary_" + std::to_string(vecComplexHash(mat)), mat,
+            m_gateDeviceMemCache));
+      m_state->applyUnitaryChannel(qubits, channelMats, probabilities);
+    } else {
+      throw std::runtime_error("Non-unitary noise channels are not supported.");
+    }
+    break;
+  }
+  default:
+    throw std::runtime_error(
+        "Unsupported noise model type: " +
+        std::to_string(static_cast<int>(krausChannel.noise_type)));
+  }
+}
+
+void SimulatorTensorNetBase::applyNoiseChannel(
+    const std::string_view gateName, const std::vector<std::size_t> &controls,
+    const std::vector<std::size_t> &targets,
+    const std::vector<double> &params) {
+  LOG_API_TIME();
+  // Do nothing if no execution context
+  if (!executionContext)
+    return;
+
+  // Do nothing if no noise model
+  if (!executionContext->noiseModel)
+    return;
+
+  // Get the name as a string
+  std::string gName(gateName);
+  std::vector<int32_t> qubits{controls.begin(), controls.end()};
+  qubits.insert(qubits.end(), targets.begin(), targets.end());
+
+  // Get the Kraus channels specified for this gate and qubits
+  auto krausChannels = executionContext->noiseModel->get_channels(
+      gName, targets, controls, params);
+
+  // If none, do nothing
+  if (krausChannels.empty())
+    return;
+
+  cudaq::info(
+      "[SimulatorTensorNetBase] Applying {} kraus channels on qubits: {}",
+      krausChannels.size(), qubits);
+
+  for (const auto &krausChannel : krausChannels)
+    applyKrausChannel(qubits, krausChannel);
+}
+
 /// @brief Reset the state of a given qubit to zero
 void SimulatorTensorNetBase::resetQubit(const std::size_t qubitIdx) {
   flushGateQueue();
@@ -225,10 +431,7 @@ cudaq::ExecutionResult
 SimulatorTensorNetBase::sample(const std::vector<std::size_t> &measuredBits,
                                const int shots) {
   LOG_API_TIME();
-  std::vector<int32_t> measuredBitIds;
-  std::transform(measuredBits.begin(), measuredBits.end(),
-                 std::back_inserter(measuredBitIds),
-                 [](std::size_t idx) { return static_cast<int32_t>(idx); });
+  std::vector<int32_t> measuredBitIds(measuredBits.begin(), measuredBits.end());
   if (shots < 1) {
     cudaq::spin_op::spin_op_term allZTerm(2 * m_state->getNumQubits(), 0);
     for (const auto &m : measuredBits)
@@ -239,7 +442,8 @@ SimulatorTensorNetBase::sample(const std::vector<std::size_t> &measuredBits,
   }
 
   prepareQubitTensorState();
-  const auto samples = m_state->sample(measuredBitIds, shots);
+  const auto samples =
+      m_state->sample(measuredBitIds, shots, requireCacheWorkspace());
   cudaq::ExecutionResult counts(samples);
   double expVal = 0.0;
   // Compute the expectation value from the counts
@@ -286,7 +490,8 @@ SimulatorTensorNetBase::observe(const cudaq::spin_op &ham) {
     // cutensornetNetworkOperator_t and compute the expectation value.
     TensorNetworkSpinOp spinOp(ham, m_cutnHandle);
     std::complex<double> expVal =
-        m_state->computeExpVal(spinOp.getNetworkOperator());
+        m_state->computeExpVal(spinOp.getNetworkOperator(),
+                               this->executionContext->numberTrajectories);
     expVal += spinOp.getIdentityTermOffset();
     return cudaq::observe_result(expVal.real(), ham,
                                  cudaq::sample_result(cudaq::ExecutionResult(
@@ -316,7 +521,8 @@ SimulatorTensorNetBase::observe(const cudaq::spin_op &ham) {
         });
 
     // Compute the expectation value for all terms
-    const auto termExpVals = m_state->computeExpVals(terms);
+    const auto termExpVals = m_state->computeExpVals(
+        terms, this->executionContext->numberTrajectories);
     std::complex<double> expVal = 0.0;
     // Construct per-term data in the final observe_result
     std::vector<cudaq::ExecutionResult> results;

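The unitary-mixture detection in the diff above rests on a simple identity: if a Kraus operator has the form K = sqrt(p) * U with U unitary, then K†K = p * I, so p can be read off the diagonal of K†K and U is recovered as K / sqrt(p). The following standalone sketch illustrates that idea under stated assumptions; it is not the code from this commit. It uses plain loops instead of Eigen, takes the matrix dimension as the square root of the element count, and the names `scaledUnitaryWeight` and `asUnitaryMixture` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

using Mat = std::vector<std::complex<double>>; // row-major, dim x dim

// If K = sqrt(p) * U with U unitary, return p (so K^dagger K = p * I);
// otherwise return nullopt. Hypothetical helper for illustration only.
std::optional<double> scaledUnitaryWeight(const Mat &k, double eps = 1e-9) {
  const auto dim = static_cast<std::size_t>(
      std::llround(std::sqrt(static_cast<double>(k.size()))));
  assert(dim > 0 && dim * dim == k.size());
  // Candidate weight: the (0, 0) entry of K^dagger K.
  double p = 0.0;
  for (std::size_t r = 0; r < dim; ++r)
    p += std::norm(k[r * dim]);
  if (p < eps)
    return std::nullopt;
  // Every entry of K^dagger K must match p * I within eps.
  for (std::size_t i = 0; i < dim; ++i)
    for (std::size_t j = 0; j < dim; ++j) {
      std::complex<double> entry = 0.0;
      for (std::size_t r = 0; r < dim; ++r)
        entry += std::conj(k[r * dim + i]) * k[r * dim + j];
      const double expected = (i == j) ? p : 0.0;
      if (std::abs(entry - expected) > eps)
        return std::nullopt;
    }
  return p;
}

// Decompose Kraus operators into (probabilities, unitaries), or return
// nullopt if the channel is not a unitary mixture.
std::optional<std::pair<std::vector<double>, std::vector<Mat>>>
asUnitaryMixture(const std::vector<Mat> &krausOps, double tol = 1e-6) {
  std::vector<double> probs;
  std::vector<Mat> unitaries;
  double total = 0.0;
  for (const auto &op : krausOps) {
    const auto p = scaledUnitaryWeight(op, tol);
    if (!p.has_value())
      return std::nullopt;
    Mat u = op;
    for (auto &x : u)
      x /= std::sqrt(*p); // U = K / sqrt(p)
    probs.push_back(*p);
    unitaries.push_back(std::move(u));
    total += *p;
  }
  // The weights of a trace-preserving unitary mixture sum to one.
  if (std::abs(1.0 - total) > tol)
    return std::nullopt;
  return std::make_pair(std::move(probs), std::move(unitaries));
}
```

For a bit-flip channel with Kraus operators sqrt(1 - p) * I and sqrt(p) * X, this sketch yields the probabilities {1 - p, p} and the unitaries {I, X}. An amplitude damping channel with a nonzero damping parameter is rejected, because its first Kraus operator gives K†K = diag(1, 1 - γ), which is not a multiple of the identity.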
runtime/nvqir/cutensornet/simulator_cutensornet.h

Lines changed: 15 additions & 0 deletions
@@ -30,6 +30,12 @@ class SimulatorTensorNetBase : public nvqir::CircuitSimulatorBase<double> {
   /// @brief Apply quantum gate
   void applyGate(const GateApplicationTask &task) override;
 
+  /// @brief Apply a noise channel
+  void applyNoiseChannel(const std::string_view gateName,
+                         const std::vector<std::size_t> &controls,
+                         const std::vector<std::size_t> &targets,
+                         const std::vector<double> &params) override;
+
   // Override base calculateStateDim (we don't instantiate full state vector in
   // the tensornet backend). When the user want to retrieve the state vector, we
   // check if it is feasible to do so.
@@ -88,6 +94,15 @@ class SimulatorTensorNetBase : public nvqir::CircuitSimulatorBase<double> {
   /// @brief Query if direct expectation value calculation is enabled
   virtual bool canHandleObserve() override;
 
+  /// @brief Return true if this simulator can use cache workspace (e.g., for
+  /// intermediate tensors)
+  virtual bool requireCacheWorkspace() const = 0;
+
+private:
+  // Helper to apply a Kraus channel
+  void applyKrausChannel(const std::vector<int32_t> &qubits,
+                         const cudaq::kraus_channel &channel);
+
 protected:
   cutensornetHandle_t m_cutnHandle;
   std::unique_ptr<TensorNetState> m_state;
