Skip to content

Commit 5f15093

Browse files
Use mutable pointers for hint outputs (#3747)
Use `MaybeUninit::as_mut_ptr()` for hint outputs written by `hint_store_u32!` and `hint_buffer_u32!`. This centralizes hint output reads in private helpers so the write-before-assume-init contract is explicit at each buffer boundary.
1 parent a49a700 commit 5f15093

1 file changed

Lines changed: 27 additions & 25 deletions

File tree

  • openvm-riscv/extensions/hints-guest/src

openvm-riscv/extensions/hints-guest/src/lib.rs

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#![no_std]
22
#[cfg(target_os = "zkvm")]
3+
use core::mem::MaybeUninit;
4+
#[cfg(target_os = "zkvm")]
35
use openvm_custom_insn; // needed for the hint_store_u32 macro
46
use strum_macros::FromRepr;
57

@@ -68,16 +70,32 @@ fn insn_k256_sqrt_field_10x26(bytes: *const u8) {
6870
);
6971
}
7072

73+
#[cfg(target_os = "zkvm")]
74+
#[inline(always)]
75+
fn read_hint_u32() -> u32 {
76+
let mut result = MaybeUninit::<u32>::uninit();
77+
unsafe {
78+
openvm_rv32im_guest::hint_store_u32!(result.as_mut_ptr());
79+
result.assume_init()
80+
}
81+
}
82+
83+
#[cfg(target_os = "zkvm")]
84+
#[inline(always)]
85+
fn read_hint_buffer<T, const WORDS: usize>() -> T {
86+
let mut result = MaybeUninit::<T>::uninit();
87+
unsafe {
88+
openvm_rv32im_guest::hint_buffer_u32!(result.as_mut_ptr() as *mut u8, WORDS);
89+
result.assume_init()
90+
}
91+
}
92+
7193
/// Just an example hint that reverses the bytes of a u32 value.
7294
pub fn hint_reverse_bytes(val: u32) -> u32 {
7395
#[cfg(target_os = "zkvm")]
7496
{
75-
let result = core::mem::MaybeUninit::<u32>::uninit();
7697
insn_reverse_bytes(&val as *const u32 as *const u8);
77-
unsafe {
78-
openvm_rv32im_guest::hint_store_u32!(result.as_ptr() as *const u32);
79-
result.assume_init()
80-
}
98+
read_hint_u32()
8199
}
82100
#[cfg(not(target_os = "zkvm"))]
83101
{
@@ -93,11 +111,7 @@ pub fn hint_reverse_bytes(val: u32) -> u32 {
93111
#[cfg(target_os = "zkvm")]
94112
pub fn hint_k256_inverse_field(sec1_bytes: &[u8]) -> [u8; 32] {
95113
insn_k256_inverse_field(sec1_bytes.as_ptr() as *const u8);
96-
let inverse = core::mem::MaybeUninit::<[u8; 32]>::uninit();
97-
unsafe {
98-
openvm_rv32im_guest::hint_buffer_u32!(inverse.as_ptr() as *const u8, 8);
99-
inverse.assume_init()
100-
}
114+
read_hint_buffer::<[u8; 32], 8>()
101115
}
102116

103117
/// Ensures that the 10 limbs are weakly normalized (i.e., the most significant limb is 22 bits and the others are 26 bits).
@@ -126,11 +140,7 @@ fn ensure_weakly_normalized_10x26(limbs: [u32; 10]) -> [u32; 10] {
126140
#[cfg(target_os = "zkvm")]
127141
pub fn hint_k256_inverse_field_10x26(elem: [u32; 10]) -> [u32; 10] {
128142
insn_k256_inverse_field_10x26(elem.as_ptr() as *const u8);
129-
let inverse = core::mem::MaybeUninit::<[u32; 10]>::uninit();
130-
let inverse = unsafe {
131-
openvm_rv32im_guest::hint_buffer_u32!(inverse.as_ptr() as *const u8, 10);
132-
inverse.assume_init()
133-
};
143+
let inverse = read_hint_buffer::<[u32; 10], 10>();
134144
ensure_weakly_normalized_10x26(inverse)
135145
}
136146

@@ -146,17 +156,9 @@ pub const K256_NON_QUADRATIC_RESIDUE: [u32; 10] = [3, 0, 0, 0, 0, 0, 0, 0, 0, 0]
146156
pub fn hint_k256_sqrt_field_10x26(elem: [u32; 10]) -> (bool, [u32; 10]) {
147157
insn_k256_sqrt_field_10x26(elem.as_ptr() as *const u8);
148158
// read the "boolean" result
149-
let has_sqrt = unsafe {
150-
let has_sqrt = core::mem::MaybeUninit::<u32>::uninit();
151-
openvm_rv32im_guest::hint_store_u32!(has_sqrt.as_ptr() as *const u32);
152-
has_sqrt.assume_init() != 0
153-
};
159+
let has_sqrt = read_hint_u32() != 0;
154160
// read the square root value
155-
let sqrt = unsafe {
156-
let sqrt = core::mem::MaybeUninit::<[u32; 10]>::uninit();
157-
openvm_rv32im_guest::hint_buffer_u32!(sqrt.as_ptr() as *const u8, 10);
158-
sqrt.assume_init()
159-
};
161+
let sqrt = read_hint_buffer::<[u32; 10], 10>();
160162
let sqrt = ensure_weakly_normalized_10x26(sqrt);
161163
(has_sqrt, sqrt)
162164
}

0 commit comments

Comments
 (0)