Skip to content

Commit dc2f1cf

Browse files
orpuente-MSswernlicesarzc
authored
Add support for custom resets using the @Reset() attribute (#1981)
This PR adds support for custom resets using the `@Reset` attribute. It consists of the following changes: - Adding a `@Reset` attribute to the language. - Piping it down to FIR. - Using the attribute to generate the correct QIR, including the `"irreversible"` attribute. - Semantic validation passes in `qsc_passes` based on the `@Reset` attribute. - The right handling of reset operations during partial evaluation. - Unit tests. - Update previous tests to include the `#1` irreversible attribute in QIR when calling the Reset operation. --------- Co-authored-by: Stefan J. Wernli <[email protected]> Co-authored-by: César Zaragoza Cortés <[email protected]>
1 parent 852e858 commit dc2f1cf

File tree

40 files changed

+629
-59
lines changed

40 files changed

+629
-59
lines changed

compiler/qsc/src/codegen/tests.rs

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,7 @@ mod adaptive_ri_profile {
10801080
10811081
declare void @__quantum__qis__x__body(%Qubit*)
10821082
1083-
declare void @__quantum__qis__reset__body(%Qubit*)
1083+
declare void @__quantum__qis__reset__body(%Qubit*) #1
10841084
10851085
declare void @__quantum__qis__mresetz__body(%Qubit*, %Result*) #1
10861086
@@ -1263,4 +1263,73 @@ mod adaptive_ri_profile {
12631263
!10 = !{i32 1, !"multiple_target_branching", i1 false}
12641264
"#]].assert_eq(&qir);
12651265
}
1266+
1267+
#[test]
1268+
fn custom_reset_generates_correct_qir() {
1269+
let source = "namespace Test {
1270+
operation Main() : Result {
1271+
use q = Qubit();
1272+
__quantum__qis__custom_reset__body(q);
1273+
M(q)
1274+
}
1275+
1276+
@Reset()
1277+
operation __quantum__qis__custom_reset__body(target: Qubit) : Unit {
1278+
body intrinsic;
1279+
}
1280+
}";
1281+
let sources = SourceMap::new([("test.qs".into(), source.into())], None);
1282+
let language_features = LanguageFeatures::default();
1283+
let capabilities = TargetCapabilityFlags::Adaptive
1284+
| TargetCapabilityFlags::QubitReset
1285+
| TargetCapabilityFlags::IntegerComputations;
1286+
1287+
let (std_id, store) = crate::compile::package_store_with_stdlib(capabilities);
1288+
let qir = get_qir(
1289+
sources,
1290+
language_features,
1291+
capabilities,
1292+
store,
1293+
&[(std_id, None)],
1294+
)
1295+
.expect("the input program set in the `source` variable should be valid Q#");
1296+
expect![[r#"
1297+
%Result = type opaque
1298+
%Qubit = type opaque
1299+
1300+
define void @ENTRYPOINT__main() #0 {
1301+
block_0:
1302+
call void @__quantum__qis__custom_reset__body(%Qubit* inttoptr (i64 0 to %Qubit*))
1303+
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*))
1304+
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null)
1305+
ret void
1306+
}
1307+
1308+
declare void @__quantum__qis__custom_reset__body(%Qubit*) #1
1309+
1310+
declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1
1311+
1312+
declare void @__quantum__rt__result_record_output(%Result*, i8*)
1313+
1314+
attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="adaptive_profile" "required_num_qubits"="1" "required_num_results"="1" }
1315+
attributes #1 = { "irreversible" }
1316+
1317+
; module flags
1318+
1319+
!llvm.module.flags = !{!0, !1, !2, !3, !4, !5, !6, !7, !8, !9, !10}
1320+
1321+
!0 = !{i32 1, !"qir_major_version", i32 1}
1322+
!1 = !{i32 7, !"qir_minor_version", i32 0}
1323+
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
1324+
!3 = !{i32 1, !"dynamic_result_management", i1 false}
1325+
!4 = !{i32 1, !"classical_ints", i1 true}
1326+
!5 = !{i32 1, !"qubit_resetting", i1 true}
1327+
!6 = !{i32 1, !"classical_floats", i1 false}
1328+
!7 = !{i32 1, !"backwards_branching", i1 false}
1329+
!8 = !{i32 1, !"classical_fixed_points", i1 false}
1330+
!9 = !{i32 1, !"user_functions", i1 false}
1331+
!10 = !{i32 1, !"multiple_target_branching", i1 false}
1332+
"#]]
1333+
.assert_eq(&qir);
1334+
}
12661335
}

compiler/qsc_codegen/src/qir.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,11 @@ impl ToQir<String> for rir::Callable {
504504
return format!(
505505
"declare {output_type} @{}({input_type}){}",
506506
self.name,
507-
if self.call_type == rir::CallableType::Measurement {
508-
// Measurement callables are a special case that needs the irreversable attribute.
507+
if matches!(
508+
self.call_type,
509+
rir::CallableType::Measurement | rir::CallableType::Reset
510+
) {
511+
// These callables are a special case that need the irreversable attribute.
509512
" #1"
510513
} else {
511514
""

compiler/qsc_fir/src/fir.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,8 @@ pub struct CallableDecl {
756756
pub functors: FunctorSetValue,
757757
/// The callable implementation.
758758
pub implementation: CallableImpl,
759+
/// The attributes of the callable, (e.g.: Measurement or Reset).
760+
pub attrs: Vec<Attr>,
759761
}
760762

761763
impl CallableDecl {
@@ -1521,6 +1523,8 @@ pub enum Attr {
15211523
EntryPoint,
15221524
/// Indicates that a callable is a measurement.
15231525
Measurement,
1526+
/// Indicates that a callable is a reset.
1527+
Reset,
15241528
}
15251529

15261530
/// A field.

compiler/qsc_frontend/src/closure.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ pub(super) fn lift(
151151
adj: None,
152152
ctl: None,
153153
ctl_adj: None,
154+
attrs: Vec::default(),
154155
};
155156

156157
(free_vars, callable)

compiler/qsc_frontend/src/lower.rs

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ pub(super) enum Error {
3030
#[error("invalid attribute arguments: expected {0}")]
3131
#[diagnostic(code("Qsc.LowerAst.InvalidAttrArgs"))]
3232
InvalidAttrArgs(String, #[label] Span),
33-
#[error("invalid use of the Measurement attribute on a function")]
33+
#[error("invalid use of the {0} attribute on a function")]
3434
#[diagnostic(help("try declaring the callable as an operation"))]
35-
#[diagnostic(code("Qsc.LowerAst.InvalidMeasurementAttrOnFunction"))]
36-
InvalidMeasurementAttrOnFunction(#[label] Span),
35+
#[diagnostic(code("Qsc.LowerAst.InvalidAttrOnFunction"))]
36+
InvalidAttrOnFunction(String, #[label] Span),
3737
#[error("missing callable body")]
3838
#[diagnostic(code("Qsc.LowerAst.MissingBody"))]
3939
MissingBody(#[label] Span),
@@ -368,6 +368,15 @@ impl With<'_> {
368368
None
369369
}
370370
},
371+
Ok(hir::Attr::Reset) => match &*attr.arg.kind {
372+
ast::ExprKind::Tuple(args) if args.is_empty() => Some(hir::Attr::Reset),
373+
_ => {
374+
self.lowerer
375+
.errors
376+
.push(Error::InvalidAttrArgs("()".to_string(), attr.arg.span));
377+
None
378+
}
379+
},
371380
Err(()) => {
372381
self.lowerer.errors.push(Error::UnknownAttr(
373382
attr.name.name.to_string(),
@@ -429,29 +438,40 @@ impl With<'_> {
429438
adj,
430439
ctl,
431440
ctl_adj,
441+
attrs: attrs.to_vec(),
442+
}
443+
}
444+
445+
fn check_invalid_attrs_on_function(&mut self, attrs: &[hir::Attr], span: Span) {
446+
const INVALID_ATTRS: [hir::Attr; 2] = [hir::Attr::Measurement, hir::Attr::Reset];
447+
448+
for invalid_attr in &INVALID_ATTRS {
449+
if attrs.contains(invalid_attr) {
450+
self.lowerer.errors.push(Error::InvalidAttrOnFunction(
451+
format!("{invalid_attr:?}"),
452+
span,
453+
));
454+
}
432455
}
433456
}
434457

435458
fn lower_callable_kind(
436459
&mut self,
437460
kind: ast::CallableKind,
438-
attrs: &[qsc_hir::hir::Attr],
461+
attrs: &[hir::Attr],
439462
span: Span,
440463
) -> hir::CallableKind {
441-
if attrs.contains(&qsc_hir::hir::Attr::Measurement) {
442-
match kind {
443-
ast::CallableKind::Operation => hir::CallableKind::Measurement,
444-
ast::CallableKind::Function => {
445-
self.lowerer
446-
.errors
447-
.push(Error::InvalidMeasurementAttrOnFunction(span));
448-
hir::CallableKind::Function
449-
}
464+
match kind {
465+
ast::CallableKind::Function => {
466+
self.check_invalid_attrs_on_function(attrs, span);
467+
hir::CallableKind::Function
450468
}
451-
} else {
452-
match kind {
453-
ast::CallableKind::Operation => hir::CallableKind::Operation,
454-
ast::CallableKind::Function => hir::CallableKind::Function,
469+
ast::CallableKind::Operation => {
470+
if attrs.contains(&hir::Attr::Measurement) {
471+
hir::CallableKind::Measurement
472+
} else {
473+
hir::CallableKind::Operation
474+
}
455475
}
456476
}
457477
}

compiler/qsc_frontend/src/lower/tests.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2221,7 +2221,8 @@ fn test_measurement_attr_on_function_issues_error() {
22212221
"#},
22222222
&expect![[r#"
22232223
[
2224-
InvalidMeasurementAttrOnFunction(
2224+
InvalidAttrOnFunction(
2225+
"Measurement",
22252226
Span {
22262227
lo: 49,
22272228
hi: 52,
@@ -2232,6 +2233,31 @@ fn test_measurement_attr_on_function_issues_error() {
22322233
);
22332234
}
22342235

2236+
#[test]
2237+
fn test_reset_attr_on_function_issues_error() {
2238+
check_errors(
2239+
indoc! {r#"
2240+
namespace Test {
2241+
@Reset()
2242+
function Foo(q: Qubit) : Unit {
2243+
body intrinsic;
2244+
}
2245+
}
2246+
"#},
2247+
&expect![[r#"
2248+
[
2249+
InvalidAttrOnFunction(
2250+
"Reset",
2251+
Span {
2252+
lo: 43,
2253+
hi: 46,
2254+
},
2255+
),
2256+
]
2257+
"#]],
2258+
);
2259+
}
2260+
22352261
#[test]
22362262
fn item_docs() {
22372263
check_hir(

compiler/qsc_hir/src/hir.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,8 @@ pub struct CallableDecl {
391391
pub ctl: Option<SpecDecl>,
392392
/// The controlled adjoint specialization.
393393
pub ctl_adj: Option<SpecDecl>,
394+
/// The attributes of the callable, (e.g.: Measurement or Reset).
395+
pub attrs: Vec<Attr>,
394396
}
395397

396398
impl CallableDecl {
@@ -1352,8 +1354,11 @@ pub enum Attr {
13521354
/// and any implementation should be ignored.
13531355
SimulatableIntrinsic,
13541356
/// Indicates that a callable is a measurement. This means that the operation will be marked as
1355-
/// "irreversible" in the generated QIR.
1357+
/// "irreversible" in the generated QIR, and output Result types will be moved to the arguments.
13561358
Measurement,
1359+
/// Indicates that a callable is a reset. This means that the operation will be marked as
1360+
/// "irreversible" in the generated QIR.
1361+
Reset,
13571362
}
13581363

13591364
impl FromStr for Attr {
@@ -1366,6 +1371,7 @@ impl FromStr for Attr {
13661371
"Unimplemented" => Ok(Self::Unimplemented),
13671372
"SimulatableIntrinsic" => Ok(Self::SimulatableIntrinsic),
13681373
"Measurement" => Ok(Self::Measurement),
1374+
"Reset" => Ok(Self::Reset),
13691375
_ => Err(()),
13701376
}
13711377
}

compiler/qsc_lowerer/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ impl Lowerer {
282282
};
283283
CallableImpl::Spec(specialized_implementation)
284284
};
285+
let attrs = lower_attrs(&decl.attrs);
285286

286287
self.assigner.reset_local();
287288
self.locals.clear();
@@ -299,6 +300,7 @@ impl Lowerer {
299300
output,
300301
functors,
301302
implementation,
303+
attrs,
302304
}
303305
}
304306

@@ -888,6 +890,7 @@ fn lower_attrs(attrs: &[hir::Attr]) -> Vec<fir::Attr> {
888890
.filter_map(|attr| match attr {
889891
hir::Attr::EntryPoint => Some(fir::Attr::EntryPoint),
890892
hir::Attr::Measurement => Some(fir::Attr::Measurement),
893+
hir::Attr::Reset => Some(fir::Attr::Reset),
891894
hir::Attr::SimulatableIntrinsic | hir::Attr::Unimplemented | hir::Attr::Config => None,
892895
})
893896
.collect()

compiler/qsc_partial_eval/src/lib.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,9 @@ impl<'a> PartialEvaluator<'a> {
13131313
if matches!(callable_decl.kind, qsc_fir::fir::CallableKind::Measurement) {
13141314
return self.measure_qubits(callable_decl, args_value, args_span);
13151315
}
1316+
if callable_decl.attrs.contains(&fir::Attr::Reset) {
1317+
return self.reset_qubits(store_item_id, callable_decl, args_value);
1318+
}
13161319

13171320
// There are a few special cases regarding intrinsic callables. Identify them and handle them properly.
13181321
match callable_decl.name.name.as_ref() {
@@ -2372,6 +2375,60 @@ impl<'a> PartialEvaluator<'a> {
23722375
result_value
23732376
}
23742377

2378+
fn reset_qubits(
2379+
&mut self,
2380+
store_item_id: StoreItemId,
2381+
callable_decl: &CallableDecl,
2382+
args_value: Value,
2383+
) -> Result<Value, Error> {
2384+
let callable_package = self.package_store.get(store_item_id.package);
2385+
let input_type: Vec<rir::Ty> = callable_package
2386+
.derive_callable_input_params(callable_decl)
2387+
.iter()
2388+
.map(|input_param| map_fir_type_to_rir_type(&input_param.ty))
2389+
.collect();
2390+
let output_type = if callable_decl.output == Ty::UNIT {
2391+
None
2392+
} else {
2393+
panic!("the expressions that make it to this point should return Unit");
2394+
};
2395+
2396+
let measurement_callable = Callable {
2397+
name: callable_decl.name.name.to_string(),
2398+
input_type,
2399+
output_type,
2400+
body: None,
2401+
call_type: CallableType::Reset,
2402+
};
2403+
2404+
// Resolve the call arguments, create the call instruction and insert it to the current block.
2405+
let (args, ctls_arg) = self
2406+
.resolve_args(
2407+
(store_item_id.package, callable_decl.input).into(),
2408+
args_value,
2409+
None,
2410+
None,
2411+
None,
2412+
)
2413+
.expect("no controls to verify");
2414+
assert!(
2415+
ctls_arg.is_none(),
2416+
"intrinsic operations cannot have controls"
2417+
);
2418+
let operands = args
2419+
.into_iter()
2420+
.map(|arg| self.map_eval_value_to_rir_operand(&arg.into_value()))
2421+
.collect();
2422+
2423+
// Check if the callable has already been added to the program and if not do so now.
2424+
let measure_callable_id = self.get_or_insert_callable(measurement_callable);
2425+
let instruction = Instruction::Call(measure_callable_id, operands, None);
2426+
let current_block = self.get_current_rir_block_mut();
2427+
current_block.0.push(instruction);
2428+
2429+
Ok(Value::unit())
2430+
}
2431+
23752432
fn release_qubit(&mut self, args_value: Value) -> Value {
23762433
let qubit = args_value.unwrap_qubit();
23772434
self.resource_manager.release_qubit(qubit);

0 commit comments

Comments
 (0)