diff --git a/libsql-server/src/auth.rs b/libsql-server/src/auth.rs index 71ec306b43..d17e6ed379 100644 --- a/libsql-server/src/auth.rs +++ b/libsql-server/src/auth.rs @@ -71,6 +71,7 @@ impl Auth { &self, auth_header: Option<&hyper::header::HeaderValue>, disable_namespaces: bool, + namespace_jwt_key: Option, ) -> Result { if self.disabled { return Ok(Authenticated::Authorized(Authorized { @@ -101,7 +102,9 @@ impl Auth { Err(AuthError::BasicRejected) } } - HttpAuthHeader::Bearer(token) => self.validate_jwt(&token, disable_namespaces), + HttpAuthHeader::Bearer(token) => { + self.validate_jwt(&token, disable_namespaces, namespace_jwt_key) + } } } @@ -109,6 +112,7 @@ impl Auth { &self, req: &tonic::Request, disable_namespaces: bool, + namespace_jwt_key: Option, ) -> Result { let metadata = req.metadata(); @@ -117,7 +121,7 @@ impl Auth { .map(|v| v.to_bytes().expect("Auth should always be ASCII")) .map(|v| HeaderValue::from_maybe_shared(v).expect("Should already be valid header")); - self.authenticate_http(auth.as_ref(), disable_namespaces) + self.authenticate_http(auth.as_ref(), disable_namespaces, namespace_jwt_key) .map_err(Into::into) } @@ -125,6 +129,7 @@ impl Auth { &self, jwt: Option<&str>, disable_namespaces: bool, + namespace_jwt_key: Option, ) -> Result { if self.disabled { return Ok(Authenticated::Authorized(Authorized { @@ -137,16 +142,21 @@ impl Auth { return Err(AuthError::JwtMissing); }; - self.validate_jwt(jwt, disable_namespaces) + self.validate_jwt(jwt, disable_namespaces, namespace_jwt_key) } fn validate_jwt( &self, jwt: &str, disable_namespaces: bool, + namespace_jwt_key: Option, ) -> Result { - let Some(jwt_key) = self.jwt_key.as_ref() else { - return Err(AuthError::JwtNotAllowed); + let jwt_key = match namespace_jwt_key.as_ref() { + Some(jwt_key) => jwt_key, + None => match self.jwt_key.as_ref() { + Some(jwt_key) => jwt_key, + None => return Err(AuthError::JwtNotAllowed), + }, }; validate_jwt(jwt_key, jwt, disable_namespaces) } @@ -368,7 +378,7 @@ mod tests { use hyper::header::HeaderValue; fn authenticate_http(auth: &Auth, header: &str) -> Result { - auth.authenticate_http(Some(&HeaderValue::from_str(header).unwrap()), false) + auth.authenticate_http(Some(&HeaderValue::from_str(header).unwrap()), false, None) } const VALID_JWT_KEY: &str = "zaMv-aFGmB7PXkjM4IrMdF6B5zCYEiEGXW3RgMjNAtc"; @@ -400,9 +410,9 @@ mod tests { #[test] fn test_default() { let auth = Auth::default(); - assert_err!(auth.authenticate_http(None, false)); + assert_err!(auth.authenticate_http(None, false, None)); assert_err!(authenticate_http(&auth, "Basic d29qdGVrOnRoZWJlYXI=")); - assert_err!(auth.authenticate_jwt(Some(VALID_JWT), false)); + assert_err!(auth.authenticate_jwt(Some(VALID_JWT), false, None)); } #[test] @@ -420,7 +430,7 @@ mod tests { assert_err!(authenticate_http(&auth, "Basic d29qdgvronrozwjlyxi=")); assert_err!(authenticate_http(&auth, "Basic d29qdGVrOnRoZWZveA==")); - assert_err!(auth.authenticate_http(None, false)); + assert_err!(auth.authenticate_http(None, false, None)); assert_err!(authenticate_http(&auth, "")); assert_err!(authenticate_http(&auth, "foobar")); assert_err!(authenticate_http(&auth, "foo bar")); @@ -457,7 +467,7 @@ mod tests { jwt_key: Some(parse_jwt_key(VALID_JWT_KEY).unwrap()), ..Auth::default() }; - assert_ok!(auth.authenticate_jwt(Some(VALID_JWT), false)); - assert_err!(auth.authenticate_jwt(Some(&VALID_JWT[..80]), false)); + assert_ok!(auth.authenticate_jwt(Some(VALID_JWT), false, None)); + assert_err!(auth.authenticate_jwt(Some(&VALID_JWT[..80]), false, None)); } } diff --git a/libsql-server/src/connection/config.rs b/libsql-server/src/connection/config.rs index 239ccefcca..51b856669d 100644 --- a/libsql-server/src/connection/config.rs +++ b/libsql-server/src/connection/config.rs @@ -18,6 +18,8 @@ pub struct DatabaseConfig { pub heartbeat_url: Option, #[serde(default)] pub bottomless_db_id: Option, + #[serde(default)] + pub jwt_key: Option, } const fn default_max_size() -> u64 { @@ -33,6 +35,7 @@ impl Default for DatabaseConfig { max_db_pages: default_max_size(), heartbeat_url: None, bottomless_db_id: None, + jwt_key: None, } } } diff --git a/libsql-server/src/hrana/ws/conn.rs b/libsql-server/src/hrana/ws/conn.rs index 0188959b8a..73c90605e0 100644 --- a/libsql-server/src/hrana/ws/conn.rs +++ b/libsql-server/src/hrana/ws/conn.rs @@ -217,9 +217,14 @@ async fn handle_hello_msg( jwt: Option, ) -> Result { let hello_res = match conn.session.as_mut() { - None => session::handle_initial_hello(&conn.server, conn.version, jwt) - .map(|session| conn.session = Some(session)), - Some(session) => session::handle_repeated_hello(&conn.server, session, jwt), + None => { + session::handle_initial_hello(&conn.server, conn.version, jwt, conn.namespace.clone()) + .await + .map(|session| conn.session = Some(session)) + } + Some(session) => { + session::handle_repeated_hello(&conn.server, session, jwt, conn.namespace.clone()).await + } }; match hello_res { diff --git a/libsql-server/src/hrana/ws/session.rs b/libsql-server/src/hrana/ws/session.rs index 9d9ec5a3b0..9074a45bfa 100644 --- a/libsql-server/src/hrana/ws/session.rs +++ b/libsql-server/src/hrana/ws/session.rs @@ -64,14 +64,19 @@ pub enum ResponseError { Batch(batch::BatchError), } -pub(super) fn handle_initial_hello( +pub(super) async fn handle_initial_hello( server: &Server, version: Version, jwt: Option, + namespace: NamespaceName, ) -> Result::Connection>> { + let namespace_jwt_key = server + .namespaces + .with(namespace, |ns| ns.jwt_key()) + .await??; let authenticated = server .auth - .authenticate_jwt(jwt.as_deref(), server.disable_namespaces) + .authenticate_jwt(jwt.as_deref(), server.disable_namespaces, namespace_jwt_key) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(Session { @@ -83,10 +88,11 @@ pub(super) fn handle_initial_hello( }) } -pub(super) fn handle_repeated_hello( +pub(super) async fn handle_repeated_hello( server: &Server, session: &mut Session<::Connection>, jwt: Option, + namespace: NamespaceName, ) -> Result<()> { if session.version < Version::Hrana2 { bail!(ProtocolError::NotSupported { @@ -94,10 +100,13 @@ pub(super) fn handle_repeated_hello( min_version: Version::Hrana2, }) } - + let namespace_jwt_key = server + .namespaces + .with(namespace, |ns| ns.jwt_key()) + .await??; session.authenticated = server .auth - .authenticate_jwt(jwt.as_deref(), server.disable_namespaces) + .authenticate_jwt(jwt.as_deref(), server.disable_namespaces, namespace_jwt_key) .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; Ok(()) } diff --git a/libsql-server/src/http/admin/mod.rs b/libsql-server/src/http/admin/mod.rs index 90abe59334..048b85ccdd 100644 --- a/libsql-server/src/http/admin/mod.rs +++ b/libsql-server/src/http/admin/mod.rs @@ -18,6 +18,7 @@ use tokio_util::io::ReaderStream; use tower_http::trace::DefaultOnResponse; use url::Url; +use crate::auth::parse_jwt_key; use crate::database::Database; use crate::error::LoadDumpError; use crate::hrana; @@ -190,6 +191,7 @@ async fn handle_get_config( block_reason: config.block_reason.clone(), max_db_size: Some(max_db_size), heartbeat_url: config.heartbeat_url.clone().map(|u| u.into()), + jwt_key: config.jwt_key.clone(), }; Ok(Json(resp)) @@ -232,6 +234,8 @@ struct HttpDatabaseConfig { max_db_size: Option, #[serde(default)] heartbeat_url: Option, + #[serde(default)] + jwt_key: Option, } async fn handle_post_config( @@ -239,6 +243,10 @@ async fn handle_post_config( Path(namespace): Path, Json(req): Json, ) -> crate::Result<()> { + if let Some(jwt_key) = req.jwt_key.as_deref() { + // Check that the jwt key is correct + parse_jwt_key(jwt_key)?; + } let store = app_state .namespaces .config_store(NamespaceName::from_string(namespace)?) @@ -253,6 +261,7 @@ async fn handle_post_config( if let Some(url) = req.heartbeat_url { config.heartbeat_url = Some(Url::parse(&url)?); } + config.jwt_key = req.jwt_key; store.store(config).await?; @@ -265,6 +274,7 @@ struct CreateNamespaceReq { max_db_size: Option, heartbeat_url: Option, bottomless_db_id: Option, + jwt_key: Option, } async fn handle_create_namespace( @@ -272,6 +282,10 @@ async fn handle_create_namespace( Path(namespace): Path, Json(req): Json, ) -> crate::Result<()> { + if let Some(jwt_key) = req.jwt_key.as_deref() { + // Check that the jwt key is correct + parse_jwt_key(jwt_key)?; + } let dump = match req.dump_url { Some(ref url) => { RestoreOption::Dump(dump_stream_from_url(url, app_state.connector.clone()).await?) @@ -297,6 +311,7 @@ async fn handle_create_namespace( if let Some(url) = req.heartbeat_url { config.heartbeat_url = Some(Url::parse(&url)?) } + config.jwt_key = req.jwt_key; store.store(config).await?; Ok(()) diff --git a/libsql-server/src/http/user/mod.rs b/libsql-server/src/http/user/mod.rs index b3be53225f..6c9ec616f7 100644 --- a/libsql-server/src/http/user/mod.rs +++ b/libsql-server/src/http/user/mod.rs @@ -479,10 +479,18 @@ where parts: &mut Parts, state: &AppState, ) -> Result { + let ns = db_factory::namespace_from_headers( + &parts.headers, + state.disable_default_namespace, + state.disable_namespaces, + )?; + let namespace_jwt_key = state.namespaces.with(ns, |ns| ns.jwt_key()).await??; let auth_header = parts.headers.get(hyper::header::AUTHORIZATION); - let auth = state - .auth - .authenticate_http(auth_header, state.disable_namespaces)?; + let auth = state.auth.authenticate_http( + auth_header, + state.disable_namespaces, + namespace_jwt_key, + )?; Ok(auth) } diff --git a/libsql-server/src/lib.rs b/libsql-server/src/lib.rs index f0f93d887b..8ef85f05d0 100644 --- a/libsql-server/src/lib.rs +++ b/libsql-server/src/lib.rs @@ -390,6 +390,7 @@ where db_config: self.db_config.clone(), base_path: self.path.clone(), auth: auth.clone(), + disable_namespaces: self.disable_namespaces, max_active_namespaces: self.max_active_namespaces, meta_store_config: self.meta_store_config.take(), }; @@ -614,6 +615,7 @@ struct Replica { db_config: DbConfig, base_path: Arc, auth: Arc, + disable_namespaces: bool, max_active_namespaces: usize, meta_store_config: Option, } @@ -649,7 +651,13 @@ impl Replica { ) .await?; let replication_service = ReplicationLogProxyService::new(channel.clone(), uri.clone()); - let proxy_service = ReplicaProxyService::new(channel, uri, self.auth.clone()); + let proxy_service = ReplicaProxyService::new( + channel, + uri, + namespaces.clone(), + self.auth.clone(), + self.disable_namespaces, + ); Ok((namespaces, proxy_service, replication_service)) } diff --git a/libsql-server/src/namespace/mod.rs b/libsql-server/src/namespace/mod.rs index e7e23184e2..2a61945ae0 100644 --- a/libsql-server/src/namespace/mod.rs +++ b/libsql-server/src/namespace/mod.rs @@ -31,6 +31,7 @@ use tokio_util::io::StreamReader; use tonic::transport::Channel; use uuid::Uuid; +use crate::auth::parse_jwt_key; use crate::auth::Authenticated; use crate::config::MetaStoreConfig; use crate::connection::config::DatabaseConfig; @@ -808,6 +809,17 @@ impl Namespace { pub fn config(&self) -> Arc { self.db_config_store.get() } + + pub fn jwt_key(&self) -> crate::Result> { + let config = self.db_config_store.get(); + if let Some(jwt_key) = config.jwt_key.as_deref() { + Ok(Some( + parse_jwt_key(jwt_key).context("Could not parse JWT decoding key")?, + )) + } else { + Ok(None) + } + } } pub struct ReplicaNamespaceConfig { diff --git a/libsql-server/src/rpc/proxy.rs b/libsql-server/src/rpc/proxy.rs index 23043fbd3b..8a7f9019a6 100644 --- a/libsql-server/src/rpc/proxy.rs +++ b/libsql-server/src/rpc/proxy.rs @@ -18,7 +18,7 @@ use uuid::Uuid; use crate::auth::{Auth, Authenticated}; use crate::connection::Connection; use crate::database::{Database, PrimaryConnection}; -use crate::namespace::{NamespaceStore, PrimaryNamespaceMaker}; +use crate::namespace::{NamespaceName, NamespaceStore, PrimaryNamespaceMaker}; use crate::query_result_builder::{ Column, QueryBuilderConfig, QueryResultBuilder, QueryResultBuilderError, }; @@ -302,6 +302,31 @@ impl ProxyService { pub fn clients(&self) -> Arc>>> { self.clients.clone() } + + async fn auth( + &self, + req: &mut tonic::Request, + namespace: NamespaceName, + ) -> Result { + let namespace_jwt_key = self.namespaces.with(namespace, |ns| ns.jwt_key()).await; + let namespace_jwt_key = match namespace_jwt_key { + Ok(Ok(jwt_key)) => Ok(jwt_key), + Err(crate::error::Error::NamespaceDoesntExist(_)) => Ok(None), + Err(e) => Err(tonic::Status::internal(format!( + "Error fetching jwt key for a namespace: {}", + e + ))), + Ok(Err(e)) => Err(tonic::Status::internal(format!( + "Error fetching jwt key for a namespace: {}", + e + ))), + }?; + Ok(if let Some(auth) = &self.auth { + auth.authenticate_grpc(&req, self.disable_namespaces, namespace_jwt_key)? + } else { + Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? + }) + } } #[derive(Debug, Default)] @@ -491,15 +516,11 @@ impl Proxy for ProxyService { async fn stream_exec( &self, - req: tonic::Request>, + mut req: tonic::Request>, ) -> Result, tonic::Status> { - let auth = if let Some(auth) = &self.auth { - auth.authenticate_grpc(&req, self.disable_namespaces)? - } else { - Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? - }; - let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + let auth = self.auth(&mut req, namespace.clone()).await?; + let (connection_maker, _new_frame_notifier) = self .namespaces .with(namespace, |ns| { @@ -531,14 +552,10 @@ impl Proxy for ProxyService { async fn execute( &self, - req: tonic::Request, + mut req: tonic::Request, ) -> Result, tonic::Status> { - let auth = if let Some(auth) = &self.auth { - auth.authenticate_grpc(&req, self.disable_namespaces)? - } else { - Authenticated::from_proxy_grpc_request(&req, self.disable_namespaces)? - }; let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + let auth = self.auth(&mut req, namespace.clone()).await?; let req = req.into_inner(); let pgm = crate::connection::program::Program::try_from(req.pgm.unwrap()) .map_err(|e| tonic::Status::new(tonic::Code::InvalidArgument, e.to_string()))?; @@ -602,16 +619,12 @@ impl Proxy for ProxyService { async fn describe( &self, - msg: tonic::Request, + mut msg: tonic::Request, ) -> Result, tonic::Status> { - let auth = if let Some(auth) = &self.auth { - auth.authenticate_grpc(&msg, self.disable_namespaces)? - } else { - Authenticated::from_proxy_grpc_request(&msg, self.disable_namespaces)? - }; + let namespace = super::extract_namespace(self.disable_namespaces, &msg)?; + let auth = self.auth(&mut msg, namespace.clone()).await?; // FIXME: copypasta from execute(), creatively extract to a helper function - let namespace = super::extract_namespace(self.disable_namespaces, &msg)?; let lock = self.clients.upgradable_read().await; let (connection_maker, _new_frame_notifier) = self .namespaces diff --git a/libsql-server/src/rpc/replica_proxy.rs b/libsql-server/src/rpc/replica_proxy.rs index 6a6951c535..6c5aeb92a9 100644 --- a/libsql-server/src/rpc/replica_proxy.rs +++ b/libsql-server/src/rpc/replica_proxy.rs @@ -8,25 +8,58 @@ use libsql_replication::rpc::proxy::{ use tokio_stream::StreamExt; use tonic::{transport::Channel, Request, Status}; -use crate::auth::Auth; +use crate::{ + auth::Auth, + namespace::{NamespaceStore, ReplicaNamespaceMaker}, +}; pub struct ReplicaProxyService { client: ProxyClient, auth: Arc, + disable_namespaces: bool, + namespaces: NamespaceStore, } impl ReplicaProxyService { - pub fn new(channel: Channel, uri: Uri, auth: Arc) -> Self { + pub fn new( + channel: Channel, + uri: Uri, + namespaces: NamespaceStore, + auth: Arc, + disable_namespaces: bool, + ) -> Self { let client = ProxyClient::with_origin(channel, uri); - Self { client, auth } + Self { + client, + auth, + disable_namespaces, + namespaces, + } } - fn do_auth(&self, req: &mut Request) -> Result<(), Status> { - let authenticated = self.auth.authenticate_grpc(req, false)?; - - authenticated.upgrade_grpc_request(req); - - Ok(()) + async fn do_auth(&self, req: &mut Request) -> Result<(), Status> { + let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + let namespace_jwt_key = self.namespaces.with(namespace, |ns| ns.jwt_key()).await; + match namespace_jwt_key { + Ok(Ok(jwt_key)) => { + let authenticated = self.auth.authenticate_grpc(req, false, jwt_key)?; + authenticated.upgrade_grpc_request(req); + Ok(()) + } + Err(crate::error::Error::NamespaceDoesntExist(_)) => { + let authenticated = self.auth.authenticate_grpc(req, false, None)?; + authenticated.upgrade_grpc_request(req); + Ok(()) + } + Err(e) => Err(Status::internal(format!( + "Error fetching jwt key for a namespace: {}", + e + ))), + Ok(Err(e)) => Err(Status::internal(format!( + "Error fetching jwt key for a namespace: {}", + e + ))), + } } } @@ -54,7 +87,7 @@ impl Proxy for ReplicaProxyService { } }; let mut req = tonic::Request::from_parts(meta, ext, stream); - self.do_auth(&mut req)?; + self.do_auth(&mut req).await?; let mut client = self.client.clone(); client.stream_exec(req).await } @@ -64,7 +97,7 @@ impl Proxy for ReplicaProxyService { mut req: tonic::Request, ) -> Result, tonic::Status> { tracing::debug!("execute"); - self.do_auth(&mut req)?; + self.do_auth(&mut req).await?; let mut client = self.client.clone(); client.execute(req).await @@ -75,7 +108,7 @@ impl Proxy for ReplicaProxyService { &self, mut msg: tonic::Request, ) -> Result, tonic::Status> { - self.do_auth(&mut msg)?; + self.do_auth(&mut msg).await?; let mut client = self.client.clone(); client.disconnect(msg).await @@ -85,7 +118,7 @@ impl Proxy for ReplicaProxyService { &self, mut req: tonic::Request, ) -> Result, tonic::Status> { - self.do_auth(&mut req)?; + self.do_auth(&mut req).await?; let mut client = self.client.clone(); client.describe(req).await diff --git a/libsql-server/src/rpc/replication_log.rs b/libsql-server/src/rpc/replication_log.rs index 8002c6e06a..d217c3c7fd 100644 --- a/libsql-server/src/rpc/replication_log.rs +++ b/libsql-server/src/rpc/replication_log.rs @@ -58,12 +58,36 @@ impl ReplicationLogService { } } - fn authenticate(&self, req: &tonic::Request) -> Result<(), Status> { - if let Some(auth) = &self.auth { - let _ = auth.authenticate_grpc(req, self.disable_namespaces)?; + async fn authenticate( + &self, + req: &tonic::Request, + namespace: NamespaceName, + ) -> Result<(), Status> { + let namespace_jwt_key = self.namespaces.with(namespace, |ns| ns.jwt_key()).await; + match namespace_jwt_key { + Ok(Ok(jwt_key)) => { + if let Some(auth) = &self.auth { + auth.authenticate_grpc(req, self.disable_namespaces, jwt_key)?; + } + Ok(()) + } + Err(e) => match e.as_ref() { + crate::error::Error::NamespaceDoesntExist(_) => { + if let Some(auth) = &self.auth { + auth.authenticate_grpc(req, self.disable_namespaces, None)?; + } + Ok(()) + } + _ => Err(Status::internal(format!( + "Error fetching jwt key for a namespace: {}", + e + ))), + }, + Ok(Err(e)) => Err(Status::internal(format!( + "Error fetching jwt key for a namespace: {}", + e + ))), } - - Ok(()) } fn verify_session_token(&self, req: &tonic::Request) -> Result<(), Status> { @@ -159,11 +183,11 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { - self.authenticate(&req)?; - self.verify_session_token(&req)?; - let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + self.authenticate(&req, namespace.clone()).await?; + self.verify_session_token(&req)?; + let req = req.into_inner(); let logger = self .namespaces @@ -191,9 +215,9 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { - self.authenticate(&req)?; - self.verify_session_token(&req)?; let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + self.authenticate(&req, namespace.clone()).await?; + self.verify_session_token(&req)?; let req = req.into_inner(); let logger = self @@ -224,8 +248,8 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { - self.authenticate(&req)?; let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + self.authenticate(&req, namespace.clone()).await?; // legacy support if req.get_ref().handshake_version.is_none() { @@ -278,9 +302,9 @@ impl ReplicationLog for ReplicationLogService { &self, req: tonic::Request, ) -> Result, Status> { - self.authenticate(&req)?; - self.verify_session_token(&req)?; let namespace = super::extract_namespace(self.disable_namespaces, &req)?; + self.authenticate(&req, namespace.clone()).await?; + self.verify_session_token(&req)?; let req = req.into_inner(); let logger = self