Skip to content

Commit

Permalink
libsql: add http_request_callback for replicas
Browse files Browse the repository at this point in the history
  • Loading branch information
LucioFranco committed Feb 21, 2024
1 parent 97efe98 commit cd26042
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 14 deletions.
11 changes: 10 additions & 1 deletion libsql/examples/local_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@ use libsql::{
async fn main() {
tracing_subscriber::fmt::init();

let db = Builder::new_local_replica("test.db").build().await.unwrap();
let db = Builder::new_local_replica("test.db")
.http_request_callback(|r| {
let _uri = r.uri_mut();

// You can modify any part of the http request you would like including headers
// and the URI.
})
.build()
.await
.unwrap();
let conn = db.connect().unwrap();

let args = std::env::args().collect::<Vec<String>>();
Expand Down
6 changes: 4 additions & 2 deletions libsql/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ cfg_replication! {
auth_token,
None,
OpenFlags::default(),
encryption_config.clone()
encryption_config.clone(),
None
).await?;

Ok(Database {
Expand Down Expand Up @@ -307,7 +308,8 @@ cfg_replication! {
version,
read_your_writes,
encryption_config.clone(),
periodic_sync
periodic_sync,
None
).await?;

Ok(Database {
Expand Down
32 changes: 29 additions & 3 deletions libsql/src/database/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ impl Builder<()> {
},
encryption_config: None,
read_your_writes: false,
periodic_sync: None
periodic_sync: None,
http_request_callback: None
},
}
}
Expand All @@ -66,6 +67,7 @@ impl Builder<()> {
flags: crate::OpenFlags::default(),
remote: None,
encryption_config: None,
http_request_callback: None
},
}
}
Expand Down Expand Up @@ -156,6 +158,7 @@ cfg_replication! {
encryption_config: Option<EncryptionConfig>,
read_your_writes: bool,
periodic_sync: Option<std::time::Duration>,
http_request_callback: Option<crate::util::HttpRequestCallback>,
}

/// Local replica configuration type in [`Builder`].
Expand All @@ -164,6 +167,7 @@ cfg_replication! {
flags: crate::OpenFlags,
remote: Option<Remote>,
encryption_config: Option<EncryptionConfig>,
http_request_callback: Option<crate::util::HttpRequestCallback>,
}

impl Builder<RemoteReplica> {
Expand Down Expand Up @@ -203,6 +207,15 @@ cfg_replication! {
self
}

pub fn http_request_callback<F>(mut self, f: F) -> Builder<RemoteReplica>
where
F: Fn(&mut http::Request<()>) + Send + Sync + 'static
{
self.inner.http_request_callback = Some(std::sync::Arc::new(f));
self

}

#[doc(hidden)]
pub fn version(mut self, version: String) -> Builder<RemoteReplica> {
self.inner.remote = self.inner.remote.version(version);
Expand All @@ -222,7 +235,8 @@ cfg_replication! {
},
encryption_config,
read_your_writes,
periodic_sync
periodic_sync,
http_request_callback
} = self.inner;

let connector = if let Some(connector) = connector {
Expand All @@ -248,7 +262,8 @@ cfg_replication! {
version,
read_your_writes,
encryption_config.clone(),
periodic_sync
periodic_sync,
http_request_callback
)
.await?;

Expand All @@ -265,13 +280,23 @@ cfg_replication! {
self
}

pub fn http_request_callback<F>(mut self, f: F) -> Builder<LocalReplica>
where
F: Fn(&mut http::Request<()>) + Send + Sync + 'static
{
self.inner.http_request_callback = Some(std::sync::Arc::new(f));
self

}

/// Build the local embedded replica database.
pub async fn build(self) -> Result<Database> {
let LocalReplica {
path,
flags,
remote,
encryption_config,
http_request_callback
} = self.inner;

let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned();
Expand Down Expand Up @@ -304,6 +329,7 @@ cfg_replication! {
version,
flags,
encryption_config.clone(),
http_request_callback
)
.await?
} else {
Expand Down
11 changes: 9 additions & 2 deletions libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ impl Database {
false,
encryption_config,
periodic_sync,
None,
)
.await
}
Expand All @@ -79,6 +80,7 @@ impl Database {
read_your_writes: bool,
encryption_config: Option<EncryptionConfig>,
periodic_sync: Option<std::time::Duration>,
http_request_callback: Option<crate::util::HttpRequestCallback>,
) -> Result<Database> {
use std::path::PathBuf;

Expand All @@ -92,6 +94,7 @@ impl Database {
endpoint.as_str().try_into().unwrap(),
auth_token,
version.as_deref(),
http_request_callback,
)
.unwrap();
let path = PathBuf::from(db_path);
Expand Down Expand Up @@ -126,7 +129,8 @@ impl Database {
let path = PathBuf::from(db_path);
let client = LocalClient::new(&path).await.unwrap();

let replicator = EmbeddedReplicator::with_local(client, path, 1000, encryption_config).await;
let replicator =
EmbeddedReplicator::with_local(client, path, 1000, encryption_config).await;

db.replication_ctx = Some(ReplicationContext {
replicator,
Expand All @@ -146,6 +150,7 @@ impl Database {
version: Option<String>,
flags: OpenFlags,
encryption_config: Option<EncryptionConfig>,
http_request_callback: Option<crate::util::HttpRequestCallback>,
) -> Result<Database> {
use std::path::PathBuf;

Expand All @@ -160,13 +165,15 @@ impl Database {
endpoint.as_str().try_into().unwrap(),
auth_token,
version.as_deref(),
http_request_callback,
)
.unwrap();

let path = PathBuf::from(db_path);
let client = LocalClient::new(&path).await.unwrap();

let replicator = EmbeddedReplicator::with_local(client, path, 1000, encryption_config).await;
let replicator =
EmbeddedReplicator::with_local(client, path, 1000, encryption_config).await;

db.replication_ctx = Some(ReplicationContext {
replicator,
Expand Down
23 changes: 20 additions & 3 deletions libsql/src/replication/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use tower_http::{
};
use uuid::Uuid;

use crate::util::ConnectorService;
use crate::util::{ConnectorService, HttpRequestCallback};

use crate::util::box_clone_service::BoxCloneService;

Expand All @@ -46,6 +46,7 @@ impl Client {
origin: Uri,
auth_token: impl AsRef<str>,
version: Option<&str>,
http_request_callback: Option<HttpRequestCallback>,
) -> anyhow::Result<Self> {
let ver = version.unwrap_or(env!("CARGO_PKG_VERSION"));

Expand All @@ -60,7 +61,7 @@ impl Client {
let ns = split_namespace(origin.host().unwrap()).unwrap_or_else(|_| "default".to_string());
let namespace = BinaryMetadataValue::from_bytes(ns.as_bytes());

let channel = GrpcChannel::new(connector);
let channel = GrpcChannel::new(connector, http_request_callback);

let interceptor = GrpcInterceptor {
auth_token,
Expand Down Expand Up @@ -119,14 +120,30 @@ pub struct GrpcChannel {
}

impl GrpcChannel {
pub fn new(connector: ConnectorService) -> Self {
pub fn new(
connector: ConnectorService,
http_request_callback: Option<HttpRequestCallback>,
) -> Self {
let client = hyper::Client::builder().build(connector);
let client = GrpcWebClientService::new(client);

let classifier = GrpcErrorsAsFailures::new().with_success(GrpcCode::FailedPrecondition);

let svc = ServiceBuilder::new()
.layer(TraceLayer::new(SharedClassifier::new(classifier)))
.map_request(move |request: http::Request<BoxBody>| {
if let Some(cb) = &http_request_callback {
let (parts, body) = request.into_parts();
let mut req_copy = http::Request::from_parts(parts, ());
cb(&mut req_copy);

let (parts, _) = req_copy.into_parts();

http::Request::from_parts(parts, body)
} else {
request
}
})
.service(client);

let client = BoxCloneService::new(svc);
Expand Down
4 changes: 4 additions & 0 deletions libsql/src/util/http.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use super::box_clone_service::BoxCloneService;
use tokio::io::{AsyncRead, AsyncWrite};

Expand All @@ -19,3 +21,5 @@ impl hyper::client::connect::Connection for Box<dyn Socket> {

pub type ConnectorService =
BoxCloneService<http::Uri, Box<dyn Socket>, Box<dyn std::error::Error + Sync + Send + 'static>>;

pub type HttpRequestCallback = Arc<dyn Fn(&mut http::Request<()>) + Send + Sync>;
4 changes: 1 addition & 3 deletions libsql/src/util/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
cfg_replication_or_remote! {
pub mod box_clone_service;
mod http;
pub(crate) use self::http::{ConnectorService, Socket};


pub(crate) use self::http::{ConnectorService, Socket, HttpRequestCallback};
}

cfg_replication_or_remote_or_hrana! {
Expand Down

0 comments on commit cd26042

Please sign in to comment.