Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Optimize zktrieState with flatten proofs #1388

Merged
merged 8 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

187 changes: 149 additions & 38 deletions bus-mapping/src/circuit_input_builder/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,35 @@ impl CircuitInputBuilder {
hex::encode(old_root),
);

let mpt_init_state = if !light_mode {
let mpt_init_state = if !l2_trace.storage_trace.flatten_proofs.is_empty() {
log::info!("always init mpt state with flatten proofs");
let mut state = ZktrieState::construct(old_root);
let zk_db = state.expose_db();
for (k, bytes) in &l2_trace.storage_trace.flatten_proofs {
zk_db.add_node_bytes(bytes, Some(k.as_bytes())).unwrap();
}
zk_db.with_key_cache(
l2_trace
.storage_trace
.address_hashes
.iter()
.map(|(k, v)| (k.as_bytes(), v.as_bytes())),
);
zk_db.with_key_cache(
l2_trace
.storage_trace
.store_key_hashes
.iter()
.map(|(k, v)| (k.as_bytes(), v.as_bytes())),
);

log::debug!(
"building partial ZktrieState done from new trace, root {}",
hex::encode(state.root())
);

Some(state)
} else if !light_mode {
let mpt_init_state = ZktrieState::from_trace_with_additional(
old_root,
Self::collect_account_proofs(&l2_trace.storage_trace),
Expand All @@ -148,21 +176,58 @@ impl CircuitInputBuilder {
};

let mut sdb = StateDB::new();
for parsed in ZktrieState::parse_account_from_proofs(Self::collect_account_proofs(
&l2_trace.storage_trace,
)) {
let (addr, acc) = parsed.map_err(Error::IoError)?;
log::trace!("sdb trace {:?} {:?}", addr, acc);
sdb.set_account(&addr, state_db::Account::from(&acc));
}
if let Some(zk_state) = &mpt_init_state {
for (addr, acc) in zk_state.query_accounts(
Self::collect_account_proofs(&l2_trace.storage_trace).map(|(addr, _)| addr),
) {
if let Some(acc) = acc {
log::trace!("sdb trace[query mode] {:?} {:?}", addr, acc);
sdb.set_account(&addr, state_db::Account::from(&acc));
} else {
log::trace!("sdb trace[query mode] {:?} for zero account", addr);
sdb.set_account(&addr, state_db::Account::zero());
}
}

for ((addr, key), val) in zk_state.query_storages(
Self::collect_storage_proofs(&l2_trace.storage_trace)
.map(|(addr, key, _)| (addr, key)),
) {
let key = key.to_word();
if let Some(val) = val {
log::trace!(
"sdb trace storage[query mode] {:?} {:?} {:?}",
addr,
key,
val
);
*sdb.get_storage_mut(&addr, &key).1 = val.into();
} else {
log::trace!(
"sdb trace storage[query mode] {:?} {:?} for zero",
addr,
key
);
*sdb.get_storage_mut(&addr, &key).1 = Default::default();
}
}
} else {
for parsed in ZktrieState::parse_account_from_proofs(Self::collect_account_proofs(
&l2_trace.storage_trace,
)) {
let (addr, acc) = parsed.map_err(Error::IoError)?;
log::trace!("sdb trace {:?} {:?}", addr, acc);
sdb.set_account(&addr, state_db::Account::from(&acc));
}

for parsed in ZktrieState::parse_storage_from_proofs(Self::collect_storage_proofs(
&l2_trace.storage_trace,
)) {
let ((addr, key), val) = parsed.map_err(Error::IoError)?;
let key = key.to_word();
log::trace!("sdb trace storage {:?} {:?} {:?}", addr, key, val);
*sdb.get_storage_mut(&addr, &key).1 = val.into();
for parsed in ZktrieState::parse_storage_from_proofs(Self::collect_storage_proofs(
&l2_trace.storage_trace,
)) {
let ((addr, key), val) = parsed.map_err(Error::IoError)?;
let key = key.to_word();
log::trace!("sdb trace storage {:?} {:?} {:?}", addr, key, val);
*sdb.get_storage_mut(&addr, &key).1 = val.into();
}
}

let mut code_db = CodeDB::new();
Expand Down Expand Up @@ -191,7 +256,31 @@ impl CircuitInputBuilder {
/// Apply more l2 traces
pub fn add_more_l2_trace(&mut self, l2_trace: BlockTrace) -> Result<(), Error> {
// update init state new data from storage
if let Some(mpt_init_state) = &mut self.mpt_init_state {
if !l2_trace.storage_trace.flatten_proofs.is_empty() {
let mpt_state = self
.mpt_init_state
.as_mut()
.expect("should have inited with flatten proof");
log::info!("add more flatten proofs to mpt state");
let zk_db = mpt_state.expose_db();
for (k, bytes) in &l2_trace.storage_trace.flatten_proofs {
zk_db.add_node_bytes(bytes, Some(k.as_bytes())).unwrap();
}
zk_db.with_key_cache(
l2_trace
.storage_trace
.address_hashes
.iter()
.map(|(k, v)| (k.as_bytes(), v.as_bytes())),
);
zk_db.with_key_cache(
l2_trace
.storage_trace
.store_key_hashes
.iter()
.map(|(k, v)| (k.as_bytes(), v.as_bytes())),
);
} else if let Some(mpt_init_state) = &mut self.mpt_init_state {
mpt_init_state.update_from_trace(
Self::collect_account_proofs(&l2_trace.storage_trace),
Self::collect_storage_proofs(&l2_trace.storage_trace),
Expand All @@ -203,40 +292,62 @@ impl CircuitInputBuilder {
);
}

let new_accounts = ZktrieState::parse_account_from_proofs(
let filtered_accounts =
Self::collect_account_proofs(&l2_trace.storage_trace).filter(|(addr, _)| {
let (existed, _) = self.sdb.get_account(addr);
!existed
}),
)
.try_fold(
HashMap::new(),
|mut m, parsed| -> Result<HashMap<_, _>, Error> {
let (addr, acc) = parsed.map_err(Error::IoError)?;
m.insert(addr, acc);
Ok(m)
},
)?;
});

let new_accounts = if let Some(zk_state) = &self.mpt_init_state {
zk_state
.query_accounts(filtered_accounts.map(|(addr, _)| addr))
.fold(HashMap::new(), |mut m, (addr, acc)| {
m.insert(addr, acc.unwrap_or_default());
m
})
} else {
ZktrieState::parse_account_from_proofs(filtered_accounts).try_fold(
HashMap::new(),
|mut m, parsed| -> Result<HashMap<_, _>, Error> {
let (addr, acc) = parsed.map_err(Error::IoError)?;
m.insert(addr, acc);
Ok(m)
},
)?
};

for (addr, acc) in new_accounts {
self.sdb.set_account(&addr, state_db::Account::from(&acc));
}

let new_storages = ZktrieState::parse_storage_from_proofs(
let filtered_storages =
Self::collect_storage_proofs(&l2_trace.storage_trace).filter(|(addr, key, _)| {
let key = key.to_word();
let (existed, _) = self.sdb.get_committed_storage(addr, &key);
!existed
}),
)
.try_fold(
HashMap::new(),
|mut m, parsed| -> Result<HashMap<(Address, Word), Word>, Error> {
let ((addr, key), val) = parsed.map_err(Error::IoError)?;
m.insert((addr, key.to_word()), val.into());
Ok(m)
},
)?;
});

let new_storages = if let Some(zk_state) = &self.mpt_init_state {
zk_state
.query_storages(filtered_storages.map(|(addr, key, _)| (addr, key)))
.fold(HashMap::new(), |mut m, ((addr, key), val)| {
if let Some(val) = val {
m.insert((addr, key.to_word()), val.into());
} else {
m.insert((addr, key.to_word()), Default::default());
}
m
})
} else {
ZktrieState::parse_storage_from_proofs(filtered_storages).try_fold(
HashMap::new(),
|mut m, parsed| -> Result<HashMap<(Address, Word), Word>, Error> {
let ((addr, key), val) = parsed.map_err(Error::IoError)?;
m.insert((addr, key.to_word()), val.into());
Ok(m)
},
)?
};

for ((addr, key), val) in new_storages {
*self.sdb.get_storage_mut(&addr, &key).1 = val;
Expand Down
9 changes: 9 additions & 0 deletions eth-types/src/l2_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,15 @@ pub struct StorageTrace {
#[serde(rename = "deletionProofs", default)]
/// additional deletion proofs
pub deletion_proofs: Vec<Bytes>,
#[serde(rename = "flattenProofs", default)]
///
pub flatten_proofs: HashMap<H256, Bytes>,
#[serde(rename = "addressHashes", default)]
///
pub address_hashes: HashMap<Address, Hash>,
#[serde(rename = "storeKeyHashes", default)]
///
pub store_key_hashes: HashMap<H256, Hash>,
}

/// extension of `GethExecTrace`, with compatible serialize form
Expand Down
51 changes: 51 additions & 0 deletions integration-tests/tests/l2_trace.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#![feature(lazy_cell)]
#![cfg(feature = "scroll")]

use bus_mapping::{
circuit_input_builder::{CircuitInputBuilder, CircuitsParams},
util::read_env_var,
};
use eth_types::l2_types::BlockTrace;
use integration_tests::log_init;
use std::fs::File;
use zkevm_circuits::witness;

fn test_circuit_input_builder_l2block(block_trace: BlockTrace) {
let params = CircuitsParams {
max_rws: 4_000_000,
max_copy_rows: 0, // dynamic
max_txs: read_env_var("MAX_TXS", 128),
max_calldata: 2_000_000,
max_inner_blocks: 64,
max_bytecode: 3_000_000,
max_mpt_rows: 2_000_000,
max_poseidon_rows: 4_000_000,
max_keccak_rows: 0,
max_exp_steps: 100_000,
max_evm_rows: 0,
max_rlp_rows: 2_070_000,
..Default::default()
};

let mut builder = CircuitInputBuilder::new_from_l2_trace(params, block_trace, false)
.expect("could not handle block tx");

builder
.finalize_building()
.expect("could not finalize building block");

log::trace!("CircuitInputBuilder: {:#?}", builder);

let mut block = witness::block_convert(&builder.block, &builder.code_db).unwrap();
block.apply_mpt_updates(&builder.mpt_init_state.unwrap());
}

#[test]
fn local_l2_trace() {
log_init();
let file_path = read_env_var("TRACE_FILE", "dump.json".to_string());
let fd = File::open(file_path).unwrap();
let trace: BlockTrace = serde_json::from_reader(fd).unwrap();

test_circuit_input_builder_l2block(trace);
}
2 changes: 1 addition & 1 deletion prover/src/zkevm/circuit/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ pub fn finalize_builder(builder: &mut CircuitInputBuilder) -> Result<Block> {
if let Some(state) = &mut builder.mpt_init_state {
if *state.root() != [0u8; 32] {
log::debug!("apply_mpt_updates");
witness_block.apply_mpt_updates(state);
witness_block.apply_mpt_updates_and_update_mpt_state(state);
log::debug!("apply_mpt_updates done");
} else {
// Empty state root means circuit capacity checking, or dummy witness block for key gen?
Expand Down
11 changes: 11 additions & 0 deletions zkevm-circuits/src/witness/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ impl Block {
pub fn apply_mpt_updates(&mut self, mpt_state: &MptState) {
self.mpt_updates.fill_state_roots(mpt_state);
}

/// Replay mpt updates to generate mpt witness, also update the mpt state with
/// calculated mpt updatings
pub fn apply_mpt_updates_and_update_mpt_state(&mut self, mpt_state: &mut MptState) {
let updated_tries = self
.mpt_updates
.fill_state_roots(mpt_state)
.into_updated_trie();
mpt_state.updated_with_trie(updated_tries);
}

/// For each tx, for each step, print the rwc at the beginning of the step,
/// and all the rw operations of the step.
pub(crate) fn debug_print_txs_steps_rw_ops(&self) {
Expand Down
3 changes: 2 additions & 1 deletion zkevm-circuits/src/witness/mpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl MptUpdates {
self.pretty_print();
}

pub(crate) fn fill_state_roots(&mut self, init_trie: &ZktrieState) {
pub(crate) fn fill_state_roots(&mut self, init_trie: &ZktrieState) -> WitnessGenerator {
let root_pair = (self.old_root, self.new_root);
self.old_root = init_trie.root().into();
log::trace!("fill_state_roots init {:?}", self.old_root);
Expand Down Expand Up @@ -223,6 +223,7 @@ impl MptUpdates {
}
log::debug!("fill_state_roots done");
self.pretty_print();
wit_gen
}

fn fill_state_roots_from_generator(
Expand Down
Loading
Loading