Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push up to 128 frames in sync #1950

Merged
merged 4 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions libsql/src/database/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ impl Builder<()> {
connector: None,
read_your_writes: true,
remote_writes: false,
push_batch_size: 0,
},
}
}
Expand Down Expand Up @@ -524,6 +525,7 @@ cfg_sync! {
connector: Option<crate::util::ConnectorService>,
remote_writes: bool,
read_your_writes: bool,
push_batch_size: u32,
}

impl Builder<SyncedDatabase> {
Expand All @@ -543,6 +545,11 @@ cfg_sync! {
self
}

/// Sets how many WAL frames are pushed to the remote in a single sync
/// request.
///
/// Leaving this at `0` (the builder default) keeps the library's
/// built-in batch size.
pub fn set_push_batch_size(mut self, v: u32) -> Builder<SyncedDatabase> {
    let batch_size = v;
    self.inner.push_batch_size = batch_size;
    self
}

/// Provide a custom http connector that will be used to create http connections.
pub fn connector<C>(mut self, connector: C) -> Builder<SyncedDatabase>
where
Expand Down Expand Up @@ -570,6 +577,7 @@ cfg_sync! {
connector,
remote_writes,
read_your_writes,
push_batch_size,
} = self.inner;

let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned();
Expand All @@ -596,6 +604,10 @@ cfg_sync! {
)
.await?;

if push_batch_size > 0 {
db.sync_ctx.as_ref().unwrap().lock().await.set_push_batch_size(push_batch_size);
}

Ok(Database {
db_type: DbType::Offline {
db,
Expand Down
43 changes: 30 additions & 13 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod transaction;
const METADATA_VERSION: u32 = 0;

const DEFAULT_MAX_RETRIES: usize = 5;
const DEFAULT_PUSH_BATCH_SIZE: u32 = 128;

#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
Expand Down Expand Up @@ -74,6 +75,7 @@ pub struct SyncContext {
sync_url: String,
auth_token: Option<HeaderValue>,
max_retries: usize,
push_batch_size: u32,
/// The current durable generation.
durable_generation: u32,
/// Represents the max_frame_no from the server.
Expand Down Expand Up @@ -102,6 +104,7 @@ impl SyncContext {
sync_url,
auth_token,
max_retries: DEFAULT_MAX_RETRIES,
push_batch_size: DEFAULT_PUSH_BATCH_SIZE,
client,
durable_generation: 1,
durable_frame_num: 0,
Expand All @@ -117,6 +120,10 @@ impl SyncContext {
Ok(me)
}

/// Overrides the number of frames pushed per sync request.
///
/// Replaces the default (`DEFAULT_PUSH_BATCH_SIZE`) set in the
/// constructor. The database builder only calls this when the user
/// configured a non-zero batch size, so `0` never reaches this field
/// through that path — TODO confirm other callers uphold the same.
pub fn set_push_batch_size(&mut self, push_batch_size: u32) {
    self.push_batch_size = push_batch_size;
}

#[tracing::instrument(skip(self))]
pub(crate) async fn pull_one_frame(
&mut self,
Expand All @@ -134,25 +141,26 @@ impl SyncContext {
self.pull_with_retry(uri, self.max_retries).await
}

#[tracing::instrument(skip(self, frame))]
pub(crate) async fn push_one_frame(
#[tracing::instrument(skip(self, frames))]
pub(crate) async fn push_frames(
&mut self,
frame: Bytes,
frames: Bytes,
generation: u32,
frame_no: u32,
frames_count: u32,
) -> Result<u32> {
let uri = format!(
"{}/sync/{}/{}/{}",
self.sync_url,
generation,
frame_no,
frame_no + 1
frame_no + frames_count
);
tracing::debug!("pushing frame");

let (generation, durable_frame_num) = self.push_with_retry(uri, frame, self.max_retries).await?;
let (generation, durable_frame_num) = self.push_with_retry(uri, frames, self.max_retries).await?;

if durable_frame_num > frame_no {
if durable_frame_num > frame_no + frames_count - 1 {
tracing::error!(
"server returned durable_frame_num larger than what we sent: sent={}, got={}",
frame_no,
Expand All @@ -162,7 +170,7 @@ impl SyncContext {
return Err(SyncError::InvalidPushFrameNoHigh(frame_no, durable_frame_num).into());
}

if durable_frame_num < frame_no {
if durable_frame_num < frame_no + frames_count - 1 {
// Update our knowledge of where the server is at frame wise.
self.durable_frame_num = durable_frame_num;

Expand All @@ -186,7 +194,7 @@ impl SyncContext {
Ok(durable_frame_num)
}

async fn push_with_retry(&self, uri: String, frame: Bytes, max_retries: usize) -> Result<(u32, u32)> {
async fn push_with_retry(&self, uri: String, body: Bytes, max_retries: usize) -> Result<(u32, u32)> {
let mut nr_retries = 0;
loop {
let mut req = http::Request::post(uri.clone());
Expand All @@ -200,7 +208,7 @@ impl SyncContext {
None => {}
}

let req = req.body(frame.clone().into()).expect("valid body");
let req = req.body(body.clone().into()).expect("valid body");

let res = self
.client
Expand Down Expand Up @@ -537,19 +545,28 @@ async fn try_push(

let mut frame_no = start_frame_no;
while frame_no <= end_frame_no {
let frame = conn.wal_get_frame(frame_no, page_size)?;
let batch_size = sync_ctx.push_batch_size.min(end_frame_no - frame_no + 1);
let mut frames = conn.wal_get_frame(frame_no, page_size)?;
if batch_size > 1 {
frames.reserve((batch_size - 1) as usize * frames.len());
}
for idx in 1..batch_size {
let frame = conn.wal_get_frame(frame_no + idx, page_size)?;
frames.extend_from_slice(frame.as_ref())
}

// The server returns its maximum frame number. To avoid resending
// frames the server already knows about, we need to update the
// frame number to the one returned by the server.
let max_frame_no = sync_ctx
.push_one_frame(frame.freeze(), generation, frame_no)
.push_frames(frames.freeze(), generation, frame_no, batch_size)
.await?;

if max_frame_no > frame_no {
frame_no = max_frame_no;
frame_no = max_frame_no + 1;
} else {
frame_no += batch_size;
}
frame_no += 1;
}

sync_ctx.write_metadata().await?;
Expand Down
18 changes: 9 additions & 9 deletions libsql/src/sync/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async fn test_sync_context_push_frame() {
let mut sync_ctx = sync_ctx;

// Push a frame and verify the response
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 0); // First frame should return max_frame_no = 0

Expand Down Expand Up @@ -56,7 +56,7 @@ async fn test_sync_context_with_auth() {
let frame = Bytes::from("test frame with auth");
let mut sync_ctx = sync_ctx;

let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);
Expand All @@ -82,7 +82,7 @@ async fn test_sync_context_multiple_frames() {
// Push multiple frames and verify incrementing frame numbers
for i in 0..3 {
let frame = Bytes::from(format!("frame data {}", i));
let durable_frame = sync_ctx.push_one_frame(frame, 1, i).await.unwrap();
let durable_frame = sync_ctx.push_frames(frame, 1, i, 1).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, i);
assert_eq!(sync_ctx.durable_frame_num(), i);
Expand All @@ -108,7 +108,7 @@ async fn test_sync_context_corrupted_metadata() {

let mut sync_ctx = sync_ctx;
let frame = Bytes::from("test frame data");
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);
Expand Down Expand Up @@ -152,7 +152,7 @@ async fn test_sync_restarts_with_lower_max_frame_no() {

let mut sync_ctx = sync_ctx;
let frame = Bytes::from("test frame data");
let durable_frame = sync_ctx.push_one_frame(frame.clone(), 1, 0).await.unwrap();
let durable_frame = sync_ctx.push_frames(frame.clone(), 1, 0, 1).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);
Expand Down Expand Up @@ -180,14 +180,14 @@ async fn test_sync_restarts_with_lower_max_frame_no() {
// This push should fail because we are ahead of the server and thus should get an invalid
// frame no error.
sync_ctx
.push_one_frame(frame.clone(), 1, frame_no)
.push_frames(frame.clone(), 1, frame_no, 1)
.await
.unwrap_err();

let frame_no = sync_ctx.durable_frame_num() + 1;
// This then should work because when the last one failed it updated our state of the server
// durable_frame_num and we should then start writing from there.
sync_ctx.push_one_frame(frame, 1, frame_no).await.unwrap();
sync_ctx.push_frames(frame, 1, frame_no, 1).await.unwrap();
}

#[tokio::test]
Expand Down Expand Up @@ -215,7 +215,7 @@ async fn test_sync_context_retry_on_error() {
server.return_error.store(true, Ordering::SeqCst);

// First attempt should fail but retry
let result = sync_ctx.push_one_frame(frame.clone(), 1, 0).await;
let result = sync_ctx.push_frames(frame.clone(), 1, 0, 1).await;
assert!(result.is_err());

// Advance time to trigger retries faster
Expand All @@ -228,7 +228,7 @@ async fn test_sync_context_retry_on_error() {
server.return_error.store(false, Ordering::SeqCst);

// Next attempt should succeed
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);
Expand Down
Loading