Skip to content

Commit ce10fe7

Browse files
authored
Refactor public declaration and public reference (#2516)
Started as a small fix but expanded to a complete refactoring of `PublicDeclaration` and `PublicReference`. Now public reference work the same way as any reference and the syntax is without the colon (`:out` -> `out`). `PublicDeclaration` are also put under definitions in the analyzer. Both public reference and public declaration use the absolute path, so it's easy to refer from one to the other now. This is needed for #2502 to match public reference to public declaration in the backend. Currently, different namespaces can have public references with the same name, and it will confuse the backend on which public reference value to fetch.
1 parent bb86aef commit ce10fe7

File tree

28 files changed

+308
-219
lines changed

28 files changed

+308
-219
lines changed

asm-to-pil/src/vm_to_constrained.rs

-1
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,6 @@ impl<T: FieldElement> VMConverter<T> {
741741

742742
fn process_assignment_value(&self, value: Expression) -> Vec<(T, AffineExpressionComponent)> {
743743
match value {
744-
Expression::PublicReference(_, _) => panic!(),
745744
Expression::IndexAccess(_, _) => panic!(),
746745
Expression::FunctionCall(_, _) => panic!(),
747746
Expression::Reference(_, reference) => {

asmopt/src/lib.rs

-1
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ fn expr_to_ref(expr: &Expression) -> Option<String> {
204204
Expression::Reference(_, NamespacedPolynomialReference { path, .. }) => {
205205
Some(path.to_string())
206206
}
207-
Expression::PublicReference(_, pref) => Some(pref.clone()),
208207
_ => None,
209208
}
210209
}

ast/src/analyzed/display.rs

+13-15
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ impl<T: Display> Display for Analyzed<T> {
9595
format_witness_column(&name, symbol, definition),
9696
)?,
9797
SymbolKind::Poly(PolynomialType::Intermediate) => unreachable!(),
98+
SymbolKind::Public() => {
99+
if let Some(FunctionValueDefinition::PublicDeclaration(decl)) =
100+
definition
101+
{
102+
writeln_indented(f, format_public_declaration(&name, decl))?;
103+
} else {
104+
unreachable!() // For now, public symbol should always have a public declaration
105+
}
106+
}
98107
SymbolKind::Other() => {
99108
assert!(symbol.stage.is_none());
100109
match definition {
@@ -159,11 +168,6 @@ impl<T: Display> Display for Analyzed<T> {
159168
panic!()
160169
}
161170
}
162-
StatementIdentifier::PublicDeclaration(name) => {
163-
let decl = &self.public_declarations[name];
164-
let name = update_namespace(&decl.name, f)?;
165-
writeln_indented(f, format_public_declaration(&name, decl))?;
166-
}
167171
StatementIdentifier::ProofItem(i) => {
168172
writeln_indented(f, &self.identities[*i])?;
169173
}
@@ -243,14 +247,7 @@ fn format_witness_column(
243247
}
244248

245249
fn format_public_declaration(name: &str, decl: &PublicDeclaration) -> String {
246-
format!(
247-
"public {name} = {}{}({});",
248-
decl.polynomial,
249-
decl.array_index
250-
.map(|i| format!("[{i}]"))
251-
.unwrap_or_default(),
252-
decl.index
253-
)
250+
format!("public {name} = {};", decl.value)
254251
}
255252

256253
impl Display for FunctionValueDefinition {
@@ -277,7 +274,8 @@ impl Display for FunctionValueDefinition {
277274
FunctionValueDefinition::TypeDeclaration(_)
278275
| FunctionValueDefinition::TypeConstructor(_, _)
279276
| FunctionValueDefinition::TraitDeclaration(_)
280-
| FunctionValueDefinition::TraitFunction(_, _) => {
277+
| FunctionValueDefinition::TraitFunction(_, _)
278+
| FunctionValueDefinition::PublicDeclaration(_) => {
281279
panic!("Should not use this formatting function.")
282280
}
283281
}
@@ -484,7 +482,7 @@ impl<T: Display> Display for AlgebraicExpression<T> {
484482
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
485483
match self {
486484
AlgebraicExpression::Reference(reference) => write!(f, "{reference}"),
487-
AlgebraicExpression::PublicReference(name) => write!(f, ":{name}"),
485+
AlgebraicExpression::PublicReference(name) => write!(f, "{name}"),
488486
AlgebraicExpression::Challenge(challenge) => {
489487
write!(
490488
f,

ast/src/analyzed/mod.rs

+90-28
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,15 @@ use crate::parsed::visitor::{Children, ExpressionVisitable};
2222
pub use crate::parsed::BinaryOperator;
2323
pub use crate::parsed::UnaryOperator;
2424
use crate::parsed::{
25-
self, ArrayExpression, EnumDeclaration, EnumVariant, NamedType, SourceReference,
26-
TraitDeclaration, TraitImplementation, TypeDeclaration,
25+
self, ArrayExpression, EnumDeclaration, EnumVariant, FunctionCall, IndexAccess, NamedType,
26+
SourceReference, TraitDeclaration, TraitImplementation, TypeDeclaration,
2727
};
2828
pub use contains_next_ref::ContainsNextRef;
2929

3030
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
3131
pub enum StatementIdentifier {
3232
/// Either an intermediate column or a definition.
3333
Definition(String),
34-
PublicDeclaration(String),
3534
/// Index into the vector of proof items / identities.
3635
ProofItem(usize),
3736
/// Index into the vector of prover functions.
@@ -44,7 +43,6 @@ pub enum StatementIdentifier {
4443
pub struct Analyzed<T> {
4544
pub definitions: HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
4645
pub solved_impls: SolvedTraitImpls,
47-
pub public_declarations: HashMap<String, PublicDeclaration>,
4846
pub intermediate_columns: HashMap<String, (Symbol, Vec<AlgebraicExpression<T>>)>,
4947
pub identities: Vec<Identity<T>>,
5048
pub prover_functions: Vec<Expression>,
@@ -122,7 +120,10 @@ impl<T> Analyzed<T> {
122120
}
123121
/// @returns the number of public inputs
124122
pub fn publics_count(&self) -> usize {
125-
self.public_declarations.len()
123+
self.definitions
124+
.iter()
125+
.filter(|(_name, (symbol, _))| matches!(symbol.kind, SymbolKind::Public()))
126+
.count()
126127
}
127128

128129
pub fn name_to_poly_id(&self) -> BTreeMap<String, PolyID> {
@@ -187,8 +188,12 @@ impl<T> Analyzed<T> {
187188
&self,
188189
) -> impl Iterator<Item = (&String, &PublicDeclaration)> {
189190
self.source_order.iter().filter_map(move |statement| {
190-
if let StatementIdentifier::PublicDeclaration(name) = statement {
191-
if let Some(public_declaration) = self.public_declarations.get(name) {
191+
if let StatementIdentifier::Definition(name) = statement {
192+
if let Some((
193+
_,
194+
Some(FunctionValueDefinition::PublicDeclaration(public_declaration)),
195+
)) = self.definitions.get(name)
196+
{
192197
return Some((name, public_declaration));
193198
}
194199
}
@@ -398,29 +403,26 @@ impl<T> Analyzed<T> {
398403
/// Retrieves (name, col_name, poly_id, offset, stage) of each public witness in the trace.
399404
pub fn get_publics(&self) -> Vec<(String, String, PolyID, usize, u8)> {
400405
let mut publics = self
401-
.public_declarations
402-
.values()
403-
.map(|public_declaration| {
406+
.public_declarations_in_source_order()
407+
.map(|(name, public_declaration)| {
404408
let column_name = public_declaration.referenced_poly_name();
405409
let (poly_id, stage) = {
406-
let symbol = &self.definitions[&public_declaration.polynomial.name].0;
410+
let symbol = &self.definitions[&public_declaration.referenced_poly().name].0;
407411
(
408412
symbol
409413
.array_elements()
410-
.nth(public_declaration.array_index.unwrap_or_default())
414+
.nth(
415+
public_declaration
416+
.referenced_poly_array_index()
417+
.unwrap_or_default(),
418+
)
411419
.unwrap()
412420
.1,
413421
symbol.stage.unwrap_or_default() as u8,
414422
)
415423
};
416-
let row_offset = public_declaration.index as usize;
417-
(
418-
public_declaration.name.clone(),
419-
column_name,
420-
poly_id,
421-
row_offset,
422-
stage,
423-
)
424+
let row_offset = public_declaration.row() as usize;
425+
(name.clone(), column_name, poly_id, row_offset, stage)
424426
})
425427
.collect::<Vec<_>>();
426428

@@ -559,6 +561,7 @@ pub fn type_from_definition(
559561
ty: trait_func.ty.clone(),
560562
})
561563
}
564+
FunctionValueDefinition::PublicDeclaration(_) => Some(Type::Expr.into()),
562565
}
563566
} else {
564567
assert!(
@@ -771,6 +774,8 @@ impl Symbol {
771774
pub enum SymbolKind {
772775
/// Fixed, witness or intermediate polynomial
773776
Poly(PolynomialType),
777+
/// Public declaration
778+
Public(),
774779
/// Other symbol, depends on the type.
775780
/// Examples include functions not of the type "int -> fe".
776781
Other(),
@@ -784,6 +789,7 @@ pub enum FunctionValueDefinition {
784789
TypeConstructor(Arc<EnumDeclaration>, EnumVariant),
785790
TraitDeclaration(TraitDeclaration),
786791
TraitFunction(Arc<TraitDeclaration>, NamedType),
792+
PublicDeclaration(PublicDeclaration),
787793
}
788794

789795
impl Children<Expression> for FunctionValueDefinition {
@@ -799,6 +805,7 @@ impl Children<Expression> for FunctionValueDefinition {
799805
FunctionValueDefinition::TypeConstructor(_, variant) => variant.children(),
800806
FunctionValueDefinition::TraitDeclaration(trait_decl) => trait_decl.children(),
801807
FunctionValueDefinition::TraitFunction(_, trait_func) => trait_func.children(),
808+
FunctionValueDefinition::PublicDeclaration(pub_decl) => pub_decl.children(),
802809
}
803810
}
804811

@@ -814,6 +821,7 @@ impl Children<Expression> for FunctionValueDefinition {
814821
FunctionValueDefinition::TypeConstructor(_, variant) => variant.children_mut(),
815822
FunctionValueDefinition::TraitDeclaration(trait_decl) => trait_decl.children_mut(),
816823
FunctionValueDefinition::TraitFunction(_, trait_func) => trait_func.children_mut(),
824+
FunctionValueDefinition::PublicDeclaration(pub_decl) => pub_decl.children_mut(),
817825
}
818826
}
819827
}
@@ -836,26 +844,80 @@ impl Children<Expression> for NamedType {
836844
}
837845
}
838846

839-
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
847+
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash, PartialEq, Eq)]
840848
pub struct PublicDeclaration {
841849
pub id: u64,
842850
pub source: SourceRef,
843851
pub name: String,
844-
pub polynomial: PolynomialReference,
845-
pub array_index: Option<usize>,
846-
/// The evaluation point of the polynomial, not the array index.
847-
pub index: DegreeType,
852+
/// The declaration value, in two possible forms: polynomial[array_index](row) OR polynomial(row)
853+
/// where "row" is the evaluation point of the polynomial.
854+
pub value: Expression,
848855
}
849856

850857
impl PublicDeclaration {
858+
pub fn referenced_poly(&self) -> &PolynomialReference {
859+
match &self.value {
860+
Expression::FunctionCall(_, FunctionCall { function, .. }) => match function.as_ref() {
861+
Expression::Reference(_, Reference::Poly(poly)) => poly,
862+
Expression::IndexAccess(_, IndexAccess { array, .. }) => match array.as_ref() {
863+
Expression::Reference(_, Reference::Poly(poly)) => poly,
864+
_ => panic!("Expected Reference."),
865+
},
866+
_ => panic!("Expected Reference or IndexAccess."),
867+
},
868+
_ => panic!("Expected FunctionCall."),
869+
}
870+
}
871+
872+
pub fn referenced_poly_array_index(&self) -> Option<usize> {
873+
match &self.value {
874+
Expression::FunctionCall(_, FunctionCall { function, .. }) => match function.as_ref() {
875+
Expression::Reference(_, Reference::Poly(_)) => None,
876+
Expression::IndexAccess(_, IndexAccess { index, .. }) => match index.as_ref() {
877+
Expression::Number(_, index) => Some(index.value.clone().try_into().unwrap()),
878+
_ => panic!("Expected Number."),
879+
},
880+
_ => panic!("Expected Reference or IndexAccess."),
881+
},
882+
_ => panic!("Expected FunctionCall."),
883+
}
884+
}
885+
886+
/// Returns the name of the polynomial referenced by the public declaration.
887+
/// Includes the array index if present.
851888
pub fn referenced_poly_name(&self) -> String {
852-
match self.array_index {
853-
Some(index) => format!("{}[{}]", self.polynomial.name, index),
854-
None => self.polynomial.name.clone(),
889+
let name = self.referenced_poly().name.clone();
890+
let index = self.referenced_poly_array_index();
891+
if let Some(index) = index {
892+
format!("{name}[{index}]")
893+
} else {
894+
name
895+
}
896+
}
897+
898+
pub fn row(&self) -> DegreeType {
899+
match &self.value {
900+
Expression::FunctionCall(_, FunctionCall { arguments, .. }) => {
901+
assert!(arguments.len() == 1);
902+
match &arguments[0] {
903+
Expression::Number(_, index) => index.value.clone().try_into().unwrap(),
904+
_ => panic!("Expected Number."),
905+
}
906+
}
907+
_ => panic!("Expected FunctionCall."),
855908
}
856909
}
857910
}
858911

912+
impl Children<Expression> for PublicDeclaration {
913+
fn children(&self) -> Box<dyn Iterator<Item = &Expression> + '_> {
914+
Box::new(self.value.children())
915+
}
916+
fn children_mut(&mut self) -> Box<dyn Iterator<Item = &mut Expression> + '_> {
917+
Box::new(self.value.children_mut())
918+
}
919+
}
920+
859921
#[derive(
860922
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
861923
)]

ast/src/asm_analysis/display.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ impl Display for LinkDefinition {
120120
write!(
121121
f,
122122
"link {}{} {};",
123-
if flag == 1.into() {
123+
if flag == 1u32.into() {
124124
"".to_string()
125125
} else {
126126
format!("if {flag} ")

ast/src/parsed/display.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ impl Display for LinkDeclaration {
172172
write!(
173173
f,
174174
"link {}{} {}",
175-
if self.flag == 1.into() {
175+
if self.flag == 1u32.into() {
176176
"".to_string()
177177
} else {
178178
format!("if {} ", self.flag)
@@ -555,7 +555,17 @@ impl Display for FunctionDefinition {
555555
)
556556
}
557557
FunctionDefinition::Expression(e) => write!(f, " = {e}"),
558-
FunctionDefinition::TypeDeclaration(_) | FunctionDefinition::TraitDeclaration(_) => {
558+
FunctionDefinition::PublicDeclaration(poly, array_index, index) => {
559+
write!(
560+
f,
561+
" = {poly}{}({index});",
562+
array_index
563+
.as_ref()
564+
.map(|i| format!("[{i}]"))
565+
.unwrap_or_default()
566+
)
567+
}
568+
FunctionDefinition::TraitDeclaration(_) | FunctionDefinition::TypeDeclaration(_) => {
559569
panic!("Should not use this formatting function.")
560570
}
561571
}
@@ -717,7 +727,6 @@ impl<Ref: Display> Display for Expression<Ref> {
717727
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
718728
match self {
719729
Expression::Reference(_, reference) => write!(f, "{reference}"),
720-
Expression::PublicReference(_, name) => write!(f, ":{name}"),
721730
Expression::Number(_, n) => write!(f, "{n}"),
722731
Expression::String(_, value) => write!(f, "{}", quote(value)),
723732
Expression::Tuple(_, items) => write!(f, "({})", format_list(items)),

0 commit comments

Comments
 (0)