diff --git a/src/action/insert_many.rs b/src/action/insert_many.rs index a03c57a6d..eaea20dfb 100644 --- a/src/action/insert_many.rs +++ b/src/action/insert_many.rs @@ -5,7 +5,7 @@ use serde::Serialize; use crate::{ coll::options::InsertManyOptions, - error::{Error, ErrorKind, IndexedWriteError, InsertManyError, Result}, + error::{Error, ErrorKind, InsertManyError, Result}, operation::Insert as Op, options::WriteConcern, results::InsertManyResult, @@ -84,6 +84,97 @@ impl<'a> InsertMany<'a> { } } +struct Cumulative { + /// number of document results processed in prior batches + offset: usize, + /// if only `inserted_ids` is populated, does not actually represent an error + maybe_failure: InsertManyError, + labels: HashSet, +} +impl Cumulative { + fn new() -> Self { + Self { + offset: Default::default(), + maybe_failure: InsertManyError { + write_errors: Default::default(), + write_concern_error: Default::default(), + inserted_ids: Default::default(), + }, + labels: Default::default(), + } + } + fn labels>(&mut self, other: impl IntoIterator) { + self.labels.extend(other.into_iter().map(Into::into)) + } + fn result(self) -> Result { + // destructuring here so compilation fails if new fields are added + // unused variable warnings would indicate non-exhaustive checking + let Cumulative { + offset: _, + maybe_failure: + InsertManyError { + write_errors, + write_concern_error, + inserted_ids, + }, + labels, + } = self; + if write_errors.as_ref().map_or(false, |we| !we.is_empty()) + || write_concern_error.is_some() + || !labels.is_empty() + { + Err(Error::new( + ErrorKind::InsertMany(InsertManyError { + write_errors, + write_concern_error, + inserted_ids, + }), + (!labels.is_empty()).then_some(labels), + )) + } else { + Ok(InsertManyResult { inserted_ids }) + } + } +} +impl std::ops::AddAssign for Cumulative { + fn add_assign(&mut self, other: InsertManyResult) { + self.maybe_failure.inserted_ids.extend( + other + .inserted_ids + .into_iter() + .map(|(id, v)| (id + self.offset, v)), + ) + } +} +impl std::ops::AddAssign for Cumulative { + fn add_assign(&mut self, other: InsertManyError) { + let this = &mut self.maybe_failure; + this.inserted_ids.extend( + other + .inserted_ids + .into_iter() + .map(|(id, v)| (id + self.offset, v)), + ); + if let Some(mut other) = other.write_errors { + if let Some(this) = &mut this.write_errors { + this.extend(other.into_iter().map(|mut x| { + x.index += self.offset; + x + })) + } else { + other.iter_mut().for_each(|x| x.index += self.offset); + this.write_errors = Some(other) + } + } + if other.write_concern_error.is_some() { + // technically the left error gets overwritten, but we should never get here + debug_assert!(this.write_concern_error.is_none()); + this.write_concern_error = other.write_concern_error + }; + self.offset = this.inserted_ids.len() + this.write_errors.as_ref().map_or(0, |we| we.len()); + } +} + #[action_impl] impl<'a> Action for InsertMany<'a> { type Future = InsertManyFuture; @@ -105,14 +196,14 @@ impl<'a> Action for InsertMany<'a> { .unwrap_or(true); let encrypted = self.coll.client().should_auto_encrypt().await; - let mut cumulative_failure: Option = None; - let mut error_labels: HashSet = Default::default(); - let mut cumulative_result: Option = None; - - let mut n_attempted = 0; + let mut cumulative = Cumulative::new(); - while n_attempted < ds.len() { - let docs: Vec<_> = ds.iter().skip(n_attempted).map(Deref::deref).collect(); + while cumulative.offset < ds.len() { + let docs: Vec<_> = ds + .iter() + .skip(cumulative.offset) + .map(Deref::deref) + .collect(); let insert = Op::new(self.coll.namespace(), docs, self.options.clone(), encrypted); match self @@ -122,72 +213,22 @@ impl<'a> Action for InsertMany<'a> { .await { Ok(result) => { - let current_batch_size = result.inserted_ids.len(); - - let cumulative_result = - cumulative_result.get_or_insert_with(InsertManyResult::new); - for (index, id) in result.inserted_ids { - cumulative_result - .inserted_ids - .insert(index + n_attempted, id); - } - - n_attempted += current_batch_size; + cumulative += result; } - Err(e) => { - let labels = e.labels().clone(); - match *e.kind { - ErrorKind::InsertMany(bw) => { - // for ordered inserts this size will be incorrect, but knowing the - // batch size isn't needed for ordered - // failures since we return immediately from - // them anyways. - let current_batch_size = bw.inserted_ids.len() - + bw.write_errors.as_ref().map(|we| we.len()).unwrap_or(0); - - let failure_ref = - cumulative_failure.get_or_insert_with(InsertManyError::new); - if let Some(write_errors) = bw.write_errors { - for err in write_errors { - let index = n_attempted + err.index; - - failure_ref - .write_errors - .get_or_insert_with(Default::default) - .push(IndexedWriteError { index, ..err }); - } - } - - if let Some(wc_error) = bw.write_concern_error { - failure_ref.write_concern_error = Some(wc_error); - } - - error_labels.extend(labels); - - if ordered { - // this will always be true since we invoked get_or_insert_with - // above. - if let Some(failure) = cumulative_failure { - return Err(Error::new( - ErrorKind::InsertMany(failure), - Some(error_labels), - )); - } - } - n_attempted += current_batch_size; + Err(e) => match &*e.kind { + ErrorKind::InsertMany(_) => { + cumulative.labels(e.labels()); + if let ErrorKind::InsertMany(bw) = *e.kind { + cumulative += bw; + } + if ordered { + break; } - _ => return Err(e), } - } + _ => return Err(e), + }, } } - - match cumulative_failure { - Some(failure) => Err(Error::new( - ErrorKind::InsertMany(failure), - Some(error_labels), - )), - None => Ok(cumulative_result.unwrap_or_else(InsertManyResult::new)), - } + cumulative.result() } } diff --git a/src/error.rs b/src/error.rs index f3d798c01..59f4cd95b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -880,16 +880,6 @@ pub struct InsertManyError { pub(crate) inserted_ids: HashMap, } -impl InsertManyError { - pub(crate) fn new() -> Self { - InsertManyError { - write_errors: None, - write_concern_error: None, - inserted_ids: Default::default(), - } - } -} - /// An error that occurred when trying to execute a write operation. #[derive(Clone, Debug, Serialize, Deserialize)] #[non_exhaustive] diff --git a/src/results.rs b/src/results.rs index 4e8a82e84..fe2d06142 100644 --- a/src/results.rs +++ b/src/results.rs @@ -45,14 +45,6 @@ pub struct InsertManyResult { pub inserted_ids: HashMap, } -impl InsertManyResult { - pub(crate) fn new() -> Self { - InsertManyResult { - inserted_ids: HashMap::new(), - } - } -} - /// The result of a [`Collection::update_one`](../struct.Collection.html#method.update_one) or /// [`Collection::update_many`](../struct.Collection.html#method.update_many) operation. #[skip_serializing_none]