From 6535303ee1b9691def1bfb3b5059f5458fb29c04 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Wed, 22 Jan 2025 12:37:14 +0200 Subject: [PATCH] libsql: WAL sync checkpoint support This patch adds support for checkpointing during WAL sync, which allows us to sync multiple checkpoint generations from the server. The protocol is as follows: 1. A client uses the pull endpoint to fetch frames. 2. When we reach the end of a generation, the server returns "I am a teapot" (yes, really; note that the client code in this patch matches this case on StatusCode::BAD_REQUEST) with a JSON body whose "generation" field contains the maximum generation number on the server. 3. If we need to pull more generations, we first checkpoint the WAL on the client, and then continue pulling frames from the newer generation. --- libsql/src/local/connection.rs | 33 +++++++- libsql/src/sync.rs | 147 ++++++++++++++++++++++++--------- libsql/src/sync/test.rs | 6 +- 3 files changed, 142 insertions(+), 44 deletions(-) diff --git a/libsql/src/local/connection.rs b/libsql/src/local/connection.rs index 000112c999..73a5223e11 100644 --- a/libsql/src/local/connection.rs +++ b/libsql/src/local/connection.rs @@ -9,6 +9,7 @@ use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction}; use crate::TransactionBehavior; use libsql_sys::ffi; +use std::cell::RefCell; use std::{ffi::c_int, fmt, path::Path, sync::Arc}; /// A connection to a libSQL database. 
@@ -451,6 +452,20 @@ impl Connection { } } + pub(crate) fn wal_checkpoint(&self, truncate: bool) -> Result<()> { + let rc = unsafe { libsql_sys::ffi::sqlite3_wal_checkpoint_v2(self.handle(), std::ptr::null(), truncate as i32, std::ptr::null_mut(), std::ptr::null_mut()) }; + if rc != 0 { + let err_msg = unsafe { libsql_sys::ffi::sqlite3_errmsg(self.handle()) }; + let err_msg = unsafe { std::ffi::CStr::from_ptr(err_msg) }; + let err_msg = err_msg.to_string_lossy().to_string(); + return Err(crate::errors::Error::SqliteFailure( + rc as std::ffi::c_int, + format!("Failed to checkpoint WAL: {}", err_msg), + )); + } + Ok(()) + } + pub(crate) fn wal_frame_count(&self) -> u32 { let mut max_frame_no: std::os::raw::c_uint = 0; unsafe { libsql_sys::ffi::libsql_wal_frame_count(self.handle(), &mut max_frame_no) }; @@ -537,18 +552,34 @@ impl Connection { pub(crate) fn wal_insert_handle(&self) -> Result> { self.wal_insert_begin()?; - Ok(WalInsertHandle { conn: self }) + Ok(WalInsertHandle { conn: self, in_session: RefCell::new(true) }) } } pub(crate) struct WalInsertHandle<'a> { conn: &'a Connection, + in_session: RefCell } impl WalInsertHandle<'_> { pub fn insert(&self, frame: &[u8]) -> Result<()> { + assert!(*self.in_session.borrow()); self.conn.wal_insert_frame(frame) } + + pub fn begin(&self) -> Result<()> { + assert!(!*self.in_session.borrow()); + self.conn.wal_insert_begin()?; + self.in_session.replace(true); + Ok(()) + } + + pub fn end(&self) -> Result<()> { + assert!(*self.in_session.borrow()); + self.conn.wal_insert_end()?; + self.in_session.replace(false); + Ok(()) + } } impl Drop for WalInsertHandle<'_> { diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index da25c2f84f..186bfcf1fe 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -61,16 +61,23 @@ impl SyncError { } } +pub enum PullResult { + /// A frame was successfully pulled. + Frame(Bytes), + /// We've reached the end of the generation. 
+ EndOfGeneration { max_generation: u32 }, +} + pub struct SyncContext { db_path: String, client: hyper::Client, sync_url: String, auth_token: Option, max_retries: usize, + /// The current durable generation. + durable_generation: u32, /// Represents the max_frame_no from the server. durable_frame_num: u32, - /// Represents the current checkpoint generation. - generation: u32, } impl SyncContext { @@ -96,8 +103,8 @@ impl SyncContext { auth_token, max_retries: DEFAULT_MAX_RETRIES, client, + durable_generation: 1, durable_frame_num: 0, - generation: 1, }; if let Err(e) = me.read_metadata().await { @@ -115,7 +122,7 @@ impl SyncContext { &mut self, generation: u32, frame_no: u32, - ) -> Result> { + ) -> Result { let uri = format!( "{}/sync/{}/{}/{}", self.sync_url, @@ -124,13 +131,7 @@ impl SyncContext { frame_no + 1 ); tracing::debug!("pulling frame"); - match self.pull_with_retry(uri, self.max_retries).await? { - Some(frame) => { - self.durable_frame_num = frame_no; - Ok(Some(frame)) - } - None => Ok(None), - } + self.pull_with_retry(uri, self.max_retries).await } #[tracing::instrument(skip(self, frame))] @@ -149,7 +150,7 @@ impl SyncContext { ); tracing::debug!("pushing frame"); - let durable_frame_num = self.push_with_retry(uri, frame, self.max_retries).await?; + let (generation, durable_frame_num) = self.push_with_retry(uri, frame, self.max_retries).await?; if durable_frame_num > frame_no { tracing::error!( @@ -178,12 +179,14 @@ impl SyncContext { tracing::debug!(?durable_frame_num, "frame successfully pushed"); // Update our last known max_frame_no from the server. 
+ tracing::debug!(?generation, ?durable_frame_num, "updating remote generation and durable_frame_num"); + self.durable_generation = generation; self.durable_frame_num = durable_frame_num; Ok(durable_frame_num) } - async fn push_with_retry(&self, uri: String, frame: Bytes, max_retries: usize) -> Result { + async fn push_with_retry(&self, uri: String, frame: Bytes, max_retries: usize) -> Result<(u32, u32)> { let mut nr_retries = 0; loop { let mut req = http::Request::post(uri.clone()); @@ -213,6 +216,14 @@ impl SyncContext { let resp = serde_json::from_slice::(&res_body[..]) .map_err(SyncError::JsonDecode)?; + let generation = resp + .get("generation") + .ok_or_else(|| SyncError::JsonValue(resp.clone()))?; + + let generation = generation + .as_u64() + .ok_or_else(|| SyncError::JsonValue(generation.clone()))?; + let max_frame_no = resp .get("max_frame_no") .ok_or_else(|| SyncError::JsonValue(resp.clone()))?; @@ -221,7 +232,7 @@ impl SyncContext { .as_u64() .ok_or_else(|| SyncError::JsonValue(max_frame_no.clone()))?; - return Ok(max_frame_no as u32); + return Ok((generation as u32, max_frame_no as u32)); } // If we've retried too many times or the error is not a server error, @@ -244,7 +255,7 @@ impl SyncContext { } } - async fn pull_with_retry(&self, uri: String, max_retries: usize) -> Result> { + async fn pull_with_retry(&self, uri: String, max_retries: usize) -> Result { let mut nr_retries = 0; loop { let mut req = http::Request::builder().method("GET").uri(uri.clone()); @@ -268,10 +279,27 @@ impl SyncContext { let frame = hyper::body::to_bytes(res.into_body()) .await .map_err(SyncError::HttpBody)?; - return Ok(Some(frame)); + return Ok(PullResult::Frame(frame)); + } + if res.status() == StatusCode::BAD_REQUEST { + let res_body = hyper::body::to_bytes(res.into_body()) + .await + .map_err(SyncError::HttpBody)?; + + let resp = serde_json::from_slice::(&res_body[..]) + .map_err(SyncError::JsonDecode)?; + + let generation = resp + .get("generation") + .ok_or_else(|| 
SyncError::JsonValue(resp.clone()))?; + + let generation = generation + .as_u64() + .ok_or_else(|| SyncError::JsonValue(generation.clone()))?; + return Ok(PullResult::EndOfGeneration { max_generation: generation as u32 }); } if res.status() == StatusCode::BAD_REQUEST { - return Ok(None); + return Err(SyncError::PullFrame(res.status(), "Bad Request".to_string()).into()); } // If we've retried too many times or the error is not a server error, // return the error. @@ -293,12 +321,18 @@ impl SyncContext { } } + + pub(crate) fn next_generation(&mut self) { + self.durable_generation += 1; + self.durable_frame_num = 0; + } + pub(crate) fn durable_frame_num(&self) -> u32 { self.durable_frame_num } - pub(crate) fn generation(&self) -> u32 { - self.generation + pub(crate) fn durable_generation(&self) -> u32 { + self.durable_generation } pub(crate) async fn write_metadata(&mut self) -> Result<()> { @@ -308,7 +342,7 @@ impl SyncContext { hash: 0, version: METADATA_VERSION, durable_frame_num: self.durable_frame_num, - generation: self.generation, + generation: self.durable_generation, }; metadata.set_hash(); @@ -350,8 +384,8 @@ impl SyncContext { metadata ); + self.durable_generation = metadata.generation; self.durable_frame_num = metadata.durable_frame_num; - self.generation = metadata.generation; Ok(()) } @@ -436,10 +470,7 @@ pub async fn sync_offline( sync_ctx: &mut SyncContext, conn: &Connection, ) -> Result { - let durable_frame_no = sync_ctx.durable_frame_num(); - let max_frame_no = conn.wal_frame_count(); - - if max_frame_no > durable_frame_no { + if is_ahead_of_remote(&sync_ctx, &conn) { match try_push(sync_ctx, conn).await { Ok(rep) => Ok(rep), Err(Error::Sync(err)) => { @@ -475,6 +506,11 @@ pub async fn sync_offline( }) } +fn is_ahead_of_remote(sync_ctx: &SyncContext, conn: &Connection) -> bool { + let max_local_frame = conn.wal_frame_count(); + max_local_frame > sync_ctx.durable_frame_num() +} + async fn try_push( sync_ctx: &mut SyncContext, conn: &Connection, @@ 
-496,7 +532,7 @@ async fn try_push( }); } - let generation = sync_ctx.generation(); // TODO: Probe from WAL. + let generation = sync_ctx.durable_generation(); let start_frame_no = sync_ctx.durable_frame_num() + 1; let end_frame_no = max_frame_no; @@ -532,29 +568,60 @@ async fn try_pull( sync_ctx: &mut SyncContext, conn: &Connection, ) -> Result { - let generation = sync_ctx.generation(); - let mut frame_no = sync_ctx.durable_frame_num() + 1; - let insert_handle = conn.wal_insert_handle()?; + let mut err = None; + loop { + let generation = sync_ctx.durable_generation(); + let frame_no = sync_ctx.durable_frame_num() + 1; match sync_ctx.pull_one_frame(generation, frame_no).await { - Ok(Some(frame)) => { + Ok(PullResult::Frame(frame)) => { insert_handle.insert(&frame)?; - frame_no += 1; + sync_ctx.durable_frame_num = frame_no; } - Ok(None) => { + Ok(PullResult::EndOfGeneration { max_generation }) => { + // If there are no more generations to pull, we're done. + if generation >= max_generation { + break; + } + insert_handle.end()?; sync_ctx.write_metadata().await?; - return Ok(crate::database::Replicated { - frame_no: None, - frames_synced: 1, - }); - } - Err(err) => { - tracing::debug!("pull_one_frame error: {:?}", err); + + // TODO: Make this crash-proof. + conn.wal_checkpoint(true)?; + + sync_ctx.next_generation(); sync_ctx.write_metadata().await?; - return Err(err); + + insert_handle.begin()?; + } + Err(e) => { + tracing::debug!("pull_one_frame error: {:?}", e); + err.replace(e); + break; } } } + // This is crash-proof because we: + // + // 1. Write WAL frame first + // 2. Write new max frame to temporary metadata + // 3. Atomically rename the temporary metadata to the real metadata + // + // If we crash before metadata rename completes, the old metadata still + // points to last successful frame, allowing safe retry from that point. + // If we happen to have the frame already in the WAL, it's fine to re-pull + // because append locally is idempotent. 
+ insert_handle.end()?; + sync_ctx.write_metadata().await?; + + if let Some(err) = err { + Err(err) + } else { + Ok(crate::database::Replicated { + frame_no: None, + frames_synced: 1, + }) + } } diff --git a/libsql/src/sync/test.rs b/libsql/src/sync/test.rs index aec89ef3b4..0e43b29bce 100644 --- a/libsql/src/sync/test.rs +++ b/libsql/src/sync/test.rs @@ -34,7 +34,7 @@ async fn test_sync_context_push_frame() { // Verify internal state was updated assert_eq!(sync_ctx.durable_frame_num(), 0); - assert_eq!(sync_ctx.generation(), 1); + assert_eq!(sync_ctx.durable_generation(), 1); assert_eq!(server.frame_count(), 1); } @@ -129,7 +129,7 @@ async fn test_sync_context_corrupted_metadata() { // Verify that the context was reset to default values assert_eq!(sync_ctx.durable_frame_num(), 0); - assert_eq!(sync_ctx.generation(), 1); + assert_eq!(sync_ctx.durable_generation(), 1); } #[tokio::test] @@ -174,7 +174,7 @@ async fn test_sync_restarts_with_lower_max_frame_no() { // Verify that the context was set to new fake values. assert_eq!(sync_ctx.durable_frame_num(), 3); - assert_eq!(sync_ctx.generation(), 1); + assert_eq!(sync_ctx.durable_generation(), 1); let frame_no = sync_ctx.durable_frame_num() + 1; // This push should fail because we are ahead of the server and thus should get an invalid