Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WAL sync checkpointing support #1928

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction};
use crate::TransactionBehavior;

use libsql_sys::ffi;
use std::cell::RefCell;
use std::{ffi::c_int, fmt, path::Path, sync::Arc};

/// A connection to a libSQL database.
Expand Down Expand Up @@ -451,6 +452,20 @@ impl Connection {
}
}

/// Runs a WAL checkpoint on this connection.
///
/// When `truncate` is `true` the checkpoint runs in `SQLITE_CHECKPOINT_TRUNCATE`
/// mode; otherwise it runs in `SQLITE_CHECKPOINT_PASSIVE` mode.
///
/// # Errors
///
/// Returns `Error::SqliteFailure` carrying the SQLite result code and the
/// connection's current error message when the checkpoint call returns a
/// non-zero result code.
pub(crate) fn wal_checkpoint(&self, truncate: bool) -> Result<()> {
    // Checkpoint mode values from the SQLite C API. The mode must be chosen
    // explicitly: casting the bool (`truncate as i32`) maps `true` to 1, which
    // is SQLITE_CHECKPOINT_FULL rather than SQLITE_CHECKPOINT_TRUNCATE.
    const CHECKPOINT_PASSIVE: std::ffi::c_int = 0;
    const CHECKPOINT_TRUNCATE: std::ffi::c_int = 3;

    let mode = if truncate {
        CHECKPOINT_TRUNCATE
    } else {
        CHECKPOINT_PASSIVE
    };

    // SAFETY: assumes `self.handle()` yields a live sqlite3 handle for the
    // duration of the call. A null database name and null out-pointers are
    // accepted by `sqlite3_wal_checkpoint_v2`.
    let rc = unsafe {
        libsql_sys::ffi::sqlite3_wal_checkpoint_v2(
            self.handle(),
            std::ptr::null(),
            mode,
            std::ptr::null_mut(),
            std::ptr::null_mut(),
        )
    };
    if rc != 0 {
        // SAFETY: `sqlite3_errmsg` returns a NUL-terminated string owned by
        // SQLite; it is copied into an owned `String` before any further call
        // on this connection.
        let err_msg = unsafe {
            std::ffi::CStr::from_ptr(libsql_sys::ffi::sqlite3_errmsg(self.handle()))
        }
        .to_string_lossy()
        .into_owned();
        return Err(crate::errors::Error::SqliteFailure(
            rc as std::ffi::c_int,
            format!("Failed to checkpoint WAL: {}", err_msg),
        ));
    }
    Ok(())
}

pub(crate) fn wal_frame_count(&self) -> u32 {
let mut max_frame_no: std::os::raw::c_uint = 0;
unsafe { libsql_sys::ffi::libsql_wal_frame_count(self.handle(), &mut max_frame_no) };
Expand Down Expand Up @@ -537,18 +552,34 @@ impl Connection {

pub(crate) fn wal_insert_handle(&self) -> Result<WalInsertHandle<'_>> {
self.wal_insert_begin()?;
Ok(WalInsertHandle { conn: self })
Ok(WalInsertHandle { conn: self, in_session: RefCell::new(true) })
}
}

/// Handle used to insert WAL frames through a borrowed [`Connection`].
///
/// `in_session` is checked with runtime assertions by the handle's methods.
pub(crate) struct WalInsertHandle<'a> {
/// Connection that the WAL insert calls are forwarded to.
conn: &'a Connection,
/// Whether a WAL insert session is currently open on `conn`; kept in a
/// `RefCell` so it can be flipped through a shared `&self`.
in_session: RefCell<bool>
}

impl WalInsertHandle<'_> {
    /// Inserts a single WAL frame through the underlying connection.
    ///
    /// # Panics
    ///
    /// Panics if no insert session is currently open.
    pub fn insert(&self, frame: &[u8]) -> Result<()> {
        assert!(*self.in_session.borrow());
        self.conn.wal_insert_frame(frame)
    }

    /// Opens a new WAL insert session on the underlying connection.
    ///
    /// The session flag is only set once the connection call has succeeded.
    ///
    /// # Panics
    ///
    /// Panics if an insert session is already open.
    pub fn begin(&self) -> Result<()> {
        assert!(!*self.in_session.borrow());
        self.conn.wal_insert_begin()?;
        *self.in_session.borrow_mut() = true;
        Ok(())
    }

    /// Closes the current WAL insert session on the underlying connection.
    ///
    /// The session flag is only cleared once the connection call has succeeded.
    ///
    /// # Panics
    ///
    /// Panics if no insert session is currently open.
    pub fn end(&self) -> Result<()> {
        assert!(*self.in_session.borrow());
        self.conn.wal_insert_end()?;
        *self.in_session.borrow_mut() = false;
        Ok(())
    }
}

impl Drop for WalInsertHandle<'_> {
Expand Down
147 changes: 107 additions & 40 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,23 @@ impl SyncError {
}
}

/// Outcome of a single frame pull request.
pub enum PullResult {
/// A frame was successfully pulled.
Frame(Bytes),
/// We've reached the end of the generation.
///
/// `max_generation` is the generation value reported in the response.
EndOfGeneration { max_generation: u32 },
}

pub struct SyncContext {
db_path: String,
client: hyper::Client<ConnectorService, Body>,
sync_url: String,
auth_token: Option<HeaderValue>,
max_retries: usize,
/// The current durable generation.
durable_generation: u32,
/// Represents the max_frame_no from the server.
durable_frame_num: u32,
/// Represents the current checkpoint generation.
generation: u32,
}

impl SyncContext {
Expand All @@ -96,8 +103,8 @@ impl SyncContext {
auth_token,
max_retries: DEFAULT_MAX_RETRIES,
client,
durable_generation: 1,
durable_frame_num: 0,
generation: 1,
};

if let Err(e) = me.read_metadata().await {
Expand All @@ -115,7 +122,7 @@ impl SyncContext {
&mut self,
generation: u32,
frame_no: u32,
) -> Result<Option<Bytes>> {
) -> Result<PullResult> {
let uri = format!(
"{}/sync/{}/{}/{}",
self.sync_url,
Expand All @@ -124,13 +131,7 @@ impl SyncContext {
frame_no + 1
);
tracing::debug!("pulling frame");
match self.pull_with_retry(uri, self.max_retries).await? {
Some(frame) => {
self.durable_frame_num = frame_no;
Ok(Some(frame))
}
None => Ok(None),
}
self.pull_with_retry(uri, self.max_retries).await
}

#[tracing::instrument(skip(self, frame))]
Expand All @@ -149,7 +150,7 @@ impl SyncContext {
);
tracing::debug!("pushing frame");

let durable_frame_num = self.push_with_retry(uri, frame, self.max_retries).await?;
let (generation, durable_frame_num) = self.push_with_retry(uri, frame, self.max_retries).await?;

if durable_frame_num > frame_no {
tracing::error!(
Expand Down Expand Up @@ -178,12 +179,14 @@ impl SyncContext {
tracing::debug!(?durable_frame_num, "frame successfully pushed");

// Update our last known max_frame_no from the server.
tracing::debug!(?generation, ?durable_frame_num, "updating remote generation and durable_frame_num");
self.durable_generation = generation;
self.durable_frame_num = durable_frame_num;

Ok(durable_frame_num)
}

async fn push_with_retry(&self, uri: String, frame: Bytes, max_retries: usize) -> Result<u32> {
async fn push_with_retry(&self, uri: String, frame: Bytes, max_retries: usize) -> Result<(u32, u32)> {
let mut nr_retries = 0;
loop {
let mut req = http::Request::post(uri.clone());
Expand Down Expand Up @@ -213,6 +216,14 @@ impl SyncContext {
let resp = serde_json::from_slice::<serde_json::Value>(&res_body[..])
.map_err(SyncError::JsonDecode)?;

let generation = resp
.get("generation")
.ok_or_else(|| SyncError::JsonValue(resp.clone()))?;

let generation = generation
.as_u64()
.ok_or_else(|| SyncError::JsonValue(generation.clone()))?;

let max_frame_no = resp
.get("max_frame_no")
.ok_or_else(|| SyncError::JsonValue(resp.clone()))?;
Expand All @@ -221,7 +232,7 @@ impl SyncContext {
.as_u64()
.ok_or_else(|| SyncError::JsonValue(max_frame_no.clone()))?;

return Ok(max_frame_no as u32);
return Ok((generation as u32, max_frame_no as u32));
}

// If we've retried too many times or the error is not a server error,
Expand All @@ -244,7 +255,7 @@ impl SyncContext {
}
}

async fn pull_with_retry(&self, uri: String, max_retries: usize) -> Result<Option<Bytes>> {
async fn pull_with_retry(&self, uri: String, max_retries: usize) -> Result<PullResult> {
let mut nr_retries = 0;
loop {
let mut req = http::Request::builder().method("GET").uri(uri.clone());
Expand All @@ -268,10 +279,27 @@ impl SyncContext {
let frame = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;
return Ok(Some(frame));
return Ok(PullResult::Frame(frame));
}
if res.status() == StatusCode::BAD_REQUEST {
let res_body = hyper::body::to_bytes(res.into_body())
.await
.map_err(SyncError::HttpBody)?;

let resp = serde_json::from_slice::<serde_json::Value>(&res_body[..])
.map_err(SyncError::JsonDecode)?;

let generation = resp
.get("generation")
.ok_or_else(|| SyncError::JsonValue(resp.clone()))?;

let generation = generation
.as_u64()
.ok_or_else(|| SyncError::JsonValue(generation.clone()))?;
return Ok(PullResult::EndOfGeneration { max_generation: generation as u32 });
}
if res.status() == StatusCode::BAD_REQUEST {
return Ok(None);
return Err(SyncError::PullFrame(res.status(), "Bad Request".to_string()).into());
}
// If we've retried too many times or the error is not a server error,
// return the error.
Expand All @@ -293,12 +321,18 @@ impl SyncContext {
}
}


/// Moves the context to the next generation: bumps `durable_generation` by
/// one and resets `durable_frame_num` to 0.
///
/// This only updates the in-memory state; it does not write metadata itself.
pub(crate) fn next_generation(&mut self) {
self.durable_generation += 1;
self.durable_frame_num = 0;
}

/// Returns the currently recorded durable frame number.
pub(crate) fn durable_frame_num(&self) -> u32 {
self.durable_frame_num
}

pub(crate) fn generation(&self) -> u32 {
self.generation
pub(crate) fn durable_generation(&self) -> u32 {
self.durable_generation
}

pub(crate) async fn write_metadata(&mut self) -> Result<()> {
Expand All @@ -308,7 +342,7 @@ impl SyncContext {
hash: 0,
version: METADATA_VERSION,
durable_frame_num: self.durable_frame_num,
generation: self.generation,
generation: self.durable_generation,
};

metadata.set_hash();
Expand Down Expand Up @@ -350,8 +384,8 @@ impl SyncContext {
metadata
);

self.durable_generation = metadata.generation;
self.durable_frame_num = metadata.durable_frame_num;
self.generation = metadata.generation;

Ok(())
}
Expand Down Expand Up @@ -436,10 +470,7 @@ pub async fn sync_offline(
sync_ctx: &mut SyncContext,
conn: &Connection,
) -> Result<crate::database::Replicated> {
let durable_frame_no = sync_ctx.durable_frame_num();
let max_frame_no = conn.wal_frame_count();

if max_frame_no > durable_frame_no {
if is_ahead_of_remote(&sync_ctx, &conn) {
match try_push(sync_ctx, conn).await {
Ok(rep) => Ok(rep),
Err(Error::Sync(err)) => {
Expand Down Expand Up @@ -475,6 +506,11 @@ pub async fn sync_offline(
})
}

/// Reports whether the local WAL frame count exceeds the durable frame number
/// recorded in `sync_ctx`.
fn is_ahead_of_remote(sync_ctx: &SyncContext, conn: &Connection) -> bool {
    conn.wal_frame_count() > sync_ctx.durable_frame_num()
}

async fn try_push(
sync_ctx: &mut SyncContext,
conn: &Connection,
Expand All @@ -496,7 +532,7 @@ async fn try_push(
});
}

let generation = sync_ctx.generation(); // TODO: Probe from WAL.
let generation = sync_ctx.durable_generation();
let start_frame_no = sync_ctx.durable_frame_num() + 1;
let end_frame_no = max_frame_no;

Expand Down Expand Up @@ -532,29 +568,60 @@ async fn try_pull(
sync_ctx: &mut SyncContext,
conn: &Connection,
) -> Result<crate::database::Replicated> {
let generation = sync_ctx.generation();
let mut frame_no = sync_ctx.durable_frame_num() + 1;

let insert_handle = conn.wal_insert_handle()?;

let mut err = None;

loop {
let generation = sync_ctx.durable_generation();
let frame_no = sync_ctx.durable_frame_num() + 1;
match sync_ctx.pull_one_frame(generation, frame_no).await {
Ok(Some(frame)) => {
Ok(PullResult::Frame(frame)) => {
insert_handle.insert(&frame)?;
frame_no += 1;
sync_ctx.durable_frame_num = frame_no;
}
Comment on lines +579 to 582
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only need to insert here, apparently. Instead of maintaining the invariants of WalInsertHandle with a RefCell and checking them at runtime, can we let inherited mutability do the job at compile time by creating and dropping the handle on each frame? I introduced WalInsertHandle in a previous PR to be able to do that, but I might have encoded the wrong semantics. What do you think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand what you mean. If you want, please submit a PR to do what you propose, and I'm happy to merge it if it's cleaner.

Ok(None) => {
Ok(PullResult::EndOfGeneration { max_generation }) => {
// If there are no more generations to pull, we're done.
if generation >= max_generation {
break;
}
insert_handle.end()?;
sync_ctx.write_metadata().await?;
return Ok(crate::database::Replicated {
frame_no: None,
frames_synced: 1,
});
}
Err(err) => {
tracing::debug!("pull_one_frame error: {:?}", err);

// TODO: Make this crash-proof.
conn.wal_checkpoint(true)?;

sync_ctx.next_generation();
sync_ctx.write_metadata().await?;
return Err(err);

insert_handle.begin()?;
}
Err(e) => {
tracing::debug!("pull_one_frame error: {:?}", e);
err.replace(e);
break;
}
}
}
// This is crash-proof because we:
//
// 1. Write WAL frame first
// 2. Write new max frame to temporary metadata
// 3. Atomically rename the temporary metadata to the real metadata
//
// If we crash before metadata rename completes, the old metadata still
// points to last successful frame, allowing safe retry from that point.
// If we happen to have the frame already in the WAL, it's fine to re-pull
// because appending locally is idempotent.
insert_handle.end()?;
sync_ctx.write_metadata().await?;

if let Some(err) = err {
Err(err)
} else {
Ok(crate::database::Replicated {
frame_no: None,
frames_synced: 1,
})
}
}
7 changes: 4 additions & 3 deletions libsql/src/sync/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async fn test_sync_context_push_frame() {

// Verify internal state was updated
assert_eq!(sync_ctx.durable_frame_num(), 0);
assert_eq!(sync_ctx.generation(), 1);
assert_eq!(sync_ctx.durable_generation(), 1);
assert_eq!(server.frame_count(), 1);
}

Expand Down Expand Up @@ -129,7 +129,7 @@ async fn test_sync_context_corrupted_metadata() {

// Verify that the context was reset to default values
assert_eq!(sync_ctx.durable_frame_num(), 0);
assert_eq!(sync_ctx.generation(), 1);
assert_eq!(sync_ctx.durable_generation(), 1);
}

#[tokio::test]
Expand Down Expand Up @@ -174,7 +174,7 @@ async fn test_sync_restarts_with_lower_max_frame_no() {

// Verify that the context was set to new fake values.
assert_eq!(sync_ctx.durable_frame_num(), 3);
assert_eq!(sync_ctx.generation(), 1);
assert_eq!(sync_ctx.durable_generation(), 1);

let frame_no = sync_ctx.durable_frame_num() + 1;
// This push should fail because we are ahead of the server and thus should get an invalid
Expand Down Expand Up @@ -376,6 +376,7 @@ impl MockServer {
if req.uri().path().contains("/sync/") {
// Return the max_frame_no that has been accepted
let response = serde_json::json!({
"generation": 1,
"max_frame_no": current_count
});

Expand Down
Loading