1
+ /* ***************************************************************-*- C++ -*-****
2
+ * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
3
+ * All rights reserved. *
4
+ * *
5
+ * This source code and the accompanying materials are made available under *
6
+ * the terms of the Apache License 2.0 which accompanies this distribution. *
7
+ ******************************************************************************/
8
+
9
+ #include " cudaq/qis/state.h"
10
+ #include " cudaq/utils/tensor.h"
11
+
12
+ #include < complex>
13
+ #include < functional>
14
+ #include < iostream>
15
+ #include < map>
16
+ #include < string>
17
+ #include < vector>
18
+
19
+ namespace cudaq {
20
+
21
+ // Limit the signature of the users callback function to accept a vector of ints
22
+ // for the degree of freedom dimensions, and a vector of complex doubles for the
23
+ // concrete parameter values.
24
+ using Func = std::function<tensor<std::complex<double >>(
25
+ std::map<int , int >, std::map<std::string, std::complex<double >>)>;
26
+
27
+ class CallbackFunction {
28
+ private:
29
+ // The user provided callback function that takes the degrees of
30
+ // freedom and a vector of complex parameters.
31
+ Func _callback_func;
32
+
33
+ public:
34
+ CallbackFunction () = default ;
35
+
36
+ template <typename Callable>
37
+ CallbackFunction (Callable &&callable) {
38
+ static_assert (
39
+ std::is_invocable_r_v<tensor<std::complex<double >>, Callable, std::map<int , int >,
40
+ std::map<std::string, std::complex<double >>>,
41
+ " Invalid callback function. Must have signature tensor<std::complex<double>>("
42
+ " std::map<int,int>, "
43
+ " std::map<std::string, std::complex<double>>)" );
44
+ _callback_func = std::forward<Callable>(callable);
45
+ }
46
+
47
+ // Copy constructor.
48
+ CallbackFunction (CallbackFunction &other) {
49
+ _callback_func = other._callback_func ;
50
+ }
51
+
52
+ CallbackFunction (const CallbackFunction &other) {
53
+ _callback_func = other._callback_func ;
54
+ }
55
+
56
+ tensor<std::complex<double >>
57
+ operator ()(std::map<int , int > degrees,
58
+ std::map<std::string, std::complex<double >> parameters) const {
59
+ return _callback_func (std::move (degrees), std::move (parameters));
60
+ }
61
+ };
62
+
63
+ using ScalarFunc = std::function<std::complex<double >(
64
+ std::map<std::string, std::complex<double >>)>;
65
+
66
+ // A scalar callback function does not need to accept the dimensions,
67
+ // therefore we will use a different function type for this specific class.
68
+ class ScalarCallbackFunction : CallbackFunction {
69
+ private:
70
+ // The user provided callback function that takes a vector of parameters.
71
+ ScalarFunc _callback_func;
72
+
73
+ public:
74
+ ScalarCallbackFunction () = default ;
75
+
76
+ template <typename Callable>
77
+ ScalarCallbackFunction (Callable &&callable) {
78
+ static_assert (
79
+ std::is_invocable_r_v<std::complex<double >, Callable,
80
+ std::map<std::string, std::complex<double >>>,
81
+ " Invalid callback function. Must have signature std::complex<double>("
82
+ " std::map<std::string, std::complex<double>>)" );
83
+ _callback_func = std::forward<Callable>(callable);
84
+ }
85
+
86
+ // Copy constructor.
87
+ ScalarCallbackFunction (ScalarCallbackFunction &other) {
88
+ _callback_func = other._callback_func ;
89
+ }
90
+
91
+ ScalarCallbackFunction (const ScalarCallbackFunction &other) {
92
+ _callback_func = other._callback_func ;
93
+ }
94
+
95
+ bool operator !() { return (!_callback_func); }
96
+
97
+ std::complex<double >
98
+ operator ()(std::map<std::string, std::complex<double >> parameters) const {
99
+ return _callback_func (std::move (parameters));
100
+ }
101
+ };
102
+
103
+ // / @brief Object used to give an error if a Definition of an elementary
104
+ // / or scalar operator is instantiated by other means than the `define`
105
+ // / class method.
106
+ class Definition {
107
+ public:
108
+ std::string id;
109
+
110
+ // The user-provided generator function should take a variable number of
111
+ // complex doubles for the parameters. It should return a
112
+ // `cudaq::tensor<std::complex<double>>` type representing the operator matrix.
113
+ CallbackFunction generator;
114
+
115
+ // Constructor.
116
+ Definition ();
117
+
118
+ // Destructor.
119
+ ~Definition ();
120
+
121
+ void create_definition (const std::string &operator_id,
122
+ std::map<int , int > expected_dimensions,
123
+ CallbackFunction &&create);
124
+
125
+ // To call the generator function
126
+ tensor<std::complex<double >> generate_matrix (
127
+ const std::map<int , int > °rees,
128
+ const std::map<std::string, std::complex<double >> ¶meters) const ;
129
+
130
+ private:
131
+ // Member variables
132
+ std::map<int , int > m_expected_dimensions;
133
+ };
134
+ } // namespace cudaq
0 commit comments