Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better optimization of intermediate polynomials #2532

Open
wants to merge 5 commits into
base: linear-constraint-remover
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 13 additions & 117 deletions pilopt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ pub fn optimize<T: FieldElement>(mut pil_file: Analyzed<T>) -> Analyzed<T> {
extract_constant_lookups(&mut pil_file);
replace_linear_witness_columns(&mut pil_file);
remove_constant_witness_columns(&mut pil_file);
remove_constant_intermediate_columns(&mut pil_file);
simplify_identities(&mut pil_file);
remove_equal_constrained_witness_columns(&mut pil_file);
inline_trivial_intermediate_polynomials(&mut pil_file);
remove_trivial_identities(&mut pil_file);
remove_duplicate_identities(&mut pil_file);

Expand Down Expand Up @@ -724,25 +723,28 @@ fn remove_constant_witness_columns<T: FieldElement>(pil_file: &mut Analyzed<T>)
substitute_polynomial_references(pil_file, constant_polys);
}

/// Identifies intermediate columns that are constrained to a single value, replaces every
/// reference to this column by the value and deletes the column.
fn remove_constant_intermediate_columns<T: FieldElement>(pil_file: &mut Analyzed<T>) {
/// Inlines `col i = e` into the references to `i` where `e` is an expression with no operations.
/// The reasoning is that intermediate columns are useful to remember intermediate computation results, but in this case
/// the intermediate results are already known.
fn inline_trivial_intermediate_polynomials<T: FieldElement>(pil_file: &mut Analyzed<T>) {
let intermediate_polys = pil_file
.intermediate_polys_in_source_order()
.filter_map(|(symbol, definitions)| {
let mut symbols_and_definitions = symbol.array_elements().zip_eq(definitions);
match symbol.is_array() {
true => None,
false => {
let ((name, poly_id), definition) = symbols_and_definitions.next().unwrap();
match definition {
AlgebraicExpression::Number(value) => {
let ((name, poly_id), value) = symbols_and_definitions.next().unwrap();
match value {
AlgebraicExpression::BinaryOperation(_) | AlgebraicExpression::UnaryOperation(_) => {
None
}
_ =>{
log::debug!(
"Determined intermediate column {name} to be constant {value}. Removing.",
"Determined intermediate column {name} to be trivial value `{value}`. Removing.",
);
Some(((name.clone(), poly_id), (*value).into()))
Some(((name, poly_id), value.clone()))
}
_ => None,
}
}
}
Expand Down Expand Up @@ -1081,47 +1083,6 @@ fn remove_duplicate_identities<T: FieldElement>(pil_file: &mut Analyzed<T>) {
pil_file.remove_identities(&to_remove);
}

/// Identifies witness columns that are directly constrained to be equal to other witness columns
/// through polynomial identities of the form "x = y" and returns a tuple ((name, id), (name, id))
/// for each pair of identified columns
fn equal_constrained<T: FieldElement>(
expression: &AlgebraicExpression<T>,
poly_id_to_array_elem: &BTreeMap<PolyID, (&String, Option<usize>)>,
) -> Option<((String, PolyID), (String, PolyID))> {
match expression {
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left,
op: AlgebraicBinaryOperator::Sub,
right,
}) => match (left.as_ref(), right.as_ref()) {
(AlgebraicExpression::Reference(l), AlgebraicExpression::Reference(r)) => {
let is_valid = |x: &AlgebraicReference, left: bool| {
x.is_witness()
&& if left {
// We don't allow the left-hand side to be an array element
// to preserve array integrity (e.g. `x = y` is valid, but `x[0] = y` is not)
poly_id_to_array_elem.get(&x.poly_id).unwrap().1.is_none()
} else {
true
}
};

if is_valid(l, true) && is_valid(r, false) && r.next == l.next {
Some(if l.poly_id > r.poly_id {
((l.name.clone(), l.poly_id), (r.name.clone(), r.poly_id))
} else {
((r.name.clone(), r.poly_id), (l.name.clone(), l.poly_id))
})
} else {
None
}
}
_ => None,
},
_ => None,
}
}

/// Tries to extract a boolean constrained witness column from a polynomial identity.
/// The pattern used is `x * (1 - x) = 0` or `(1 - x) * x = 0` where `x` is a witness column.
fn try_to_boolean_constrained<T: FieldElement>(id: &Identity<T>) -> Option<PolyID> {
Expand Down Expand Up @@ -1180,68 +1141,3 @@ fn try_to_boolean_constrained<T: FieldElement>(id: &Identity<T>) -> Option<PolyI
None
}
}

fn remove_equal_constrained_witness_columns<T: FieldElement>(pil_file: &mut Analyzed<T>) {
let poly_id_to_array_elem = build_poly_id_to_definition_name_lookup(pil_file);
let mut substitutions: BTreeMap<(String, PolyID), (String, PolyID)> = pil_file
.identities
.iter()
.filter_map(|id| {
if let Identity::Polynomial(PolynomialIdentity { expression, .. }) = id {
equal_constrained(expression, &poly_id_to_array_elem)
} else {
None
}
})
.collect();

resolve_transitive_substitutions(&mut substitutions);

let (subs_by_id, subs_by_name): (HashMap<_, _>, HashMap<_, _>) = substitutions
.iter()
.map(|(k, v)| ((k.1, v), (&k.0, v)))
.unzip();

pil_file.post_visit_expressions_in_identities_mut(&mut |e: &mut AlgebraicExpression<_>| {
if let AlgebraicExpression::Reference(ref mut reference) = e {
if let Some((replacement_name, replacement_id)) = subs_by_id.get(&reference.poly_id) {
reference.poly_id = *replacement_id;
reference.name = replacement_name.clone();
}
}
});

pil_file.post_visit_expressions_mut(&mut |e: &mut Expression| {
if let Expression::Reference(_, Reference::Poly(reference)) = e {
if let Some((replacement_name, _)) = subs_by_name.get(&reference.name) {
reference.name = replacement_name.clone();
}
}
});
}

fn resolve_transitive_substitutions(subs: &mut BTreeMap<(String, PolyID), (String, PolyID)>) {
let mut changed = true;
while changed {
changed = false;
let keys: Vec<_> = subs
.keys()
.map(|(name, id)| (name.to_string(), *id))
.collect();

for key in keys {
let Some(target_key) = subs.get(&key) else {
continue;
};

let Some(new_target) = subs.get(target_key) else {
continue;
};

if subs.get(&key).unwrap() != new_target {
subs.insert(key, new_target.clone());
changed = true;
}
}
}
}
14 changes: 6 additions & 8 deletions pilopt/tests/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ fn replace_fixed() {
one * Y = zero * Y + 7 * X * X;
"#;
let expectation = r#"namespace N(65536);
col witness X;
col witness Y;
query |i| {
let _: expr = 1_expr;
};
N::X = 7 * N::X * N::X;
N::Y = 7 * N::Y * N::Y;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
Expand Down Expand Up @@ -120,10 +120,8 @@ fn intermediate() {
"#;
let expectation = r#"namespace N(65536);
col witness x;
col intermediate = N::x;
col int2 = N::intermediate * N::x;
col int3 = N::int2;
N::int3 = 3 * N::x + N::x;
col int2 = N::x * N::x;
N::int2 = 3 * N::x + N::x;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
Expand Down Expand Up @@ -459,8 +457,8 @@ fn equal_constrained_transitive() {
a + b + c = 5;
"#;
let expectation = r#"namespace N(65536);
col witness a;
N::a + N::a + N::a = 5;
col witness c;
N::c + N::c + N::c = 5;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
Expand Down
Loading