Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable per namespace jwt key #861

Merged
merged 13 commits into from
Jan 11, 2024
32 changes: 21 additions & 11 deletions libsql-server/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ impl Auth {
&self,
auth_header: Option<&hyper::header::HeaderValue>,
disable_namespaces: bool,
namespace_jwt_key: Option<jsonwebtoken::DecodingKey>,
) -> Result<Authenticated, AuthError> {
if self.disabled {
return Ok(Authenticated::Authorized(Authorized {
Expand Down Expand Up @@ -101,14 +102,17 @@ impl Auth {
Err(AuthError::BasicRejected)
}
}
HttpAuthHeader::Bearer(token) => self.validate_jwt(&token, disable_namespaces),
HttpAuthHeader::Bearer(token) => {
self.validate_jwt(&token, disable_namespaces, namespace_jwt_key)
}
}
}

pub fn authenticate_grpc<T>(
&self,
req: &tonic::Request<T>,
disable_namespaces: bool,
namespace_jwt_key: Option<jsonwebtoken::DecodingKey>,
) -> Result<Authenticated, Status> {
let metadata = req.metadata();

Expand All @@ -117,14 +121,15 @@ impl Auth {
.map(|v| v.to_bytes().expect("Auth should always be ASCII"))
.map(|v| HeaderValue::from_maybe_shared(v).expect("Should already be valid header"));

self.authenticate_http(auth.as_ref(), disable_namespaces)
self.authenticate_http(auth.as_ref(), disable_namespaces, namespace_jwt_key)
.map_err(Into::into)
}

pub fn authenticate_jwt(
&self,
jwt: Option<&str>,
disable_namespaces: bool,
namespace_jwt_key: Option<jsonwebtoken::DecodingKey>,
) -> Result<Authenticated, AuthError> {
if self.disabled {
return Ok(Authenticated::Authorized(Authorized {
Expand All @@ -137,16 +142,21 @@ impl Auth {
return Err(AuthError::JwtMissing);
};

self.validate_jwt(jwt, disable_namespaces)
self.validate_jwt(jwt, disable_namespaces, namespace_jwt_key)
}

fn validate_jwt(
&self,
jwt: &str,
disable_namespaces: bool,
namespace_jwt_key: Option<jsonwebtoken::DecodingKey>,
) -> Result<Authenticated, AuthError> {
let Some(jwt_key) = self.jwt_key.as_ref() else {
return Err(AuthError::JwtNotAllowed);
let jwt_key = match namespace_jwt_key.as_ref() {
Some(jwt_key) => jwt_key,
None => match self.jwt_key.as_ref() {
Some(jwt_key) => jwt_key,
None => return Err(AuthError::JwtNotAllowed),
},
};
validate_jwt(jwt_key, jwt, disable_namespaces)
}
Expand Down Expand Up @@ -368,7 +378,7 @@ mod tests {
use hyper::header::HeaderValue;

fn authenticate_http(auth: &Auth, header: &str) -> Result<Authenticated, AuthError> {
auth.authenticate_http(Some(&HeaderValue::from_str(header).unwrap()), false)
auth.authenticate_http(Some(&HeaderValue::from_str(header).unwrap()), false, None)
}

const VALID_JWT_KEY: &str = "zaMv-aFGmB7PXkjM4IrMdF6B5zCYEiEGXW3RgMjNAtc";
Expand Down Expand Up @@ -400,9 +410,9 @@ mod tests {
#[test]
fn test_default() {
let auth = Auth::default();
assert_err!(auth.authenticate_http(None, false));
assert_err!(auth.authenticate_http(None, false, None));
assert_err!(authenticate_http(&auth, "Basic d29qdGVrOnRoZWJlYXI="));
assert_err!(auth.authenticate_jwt(Some(VALID_JWT), false));
assert_err!(auth.authenticate_jwt(Some(VALID_JWT), false, None));
}

#[test]
Expand All @@ -420,7 +430,7 @@ mod tests {
assert_err!(authenticate_http(&auth, "Basic d29qdgvronrozwjlyxi="));
assert_err!(authenticate_http(&auth, "Basic d29qdGVrOnRoZWZveA=="));

assert_err!(auth.authenticate_http(None, false));
assert_err!(auth.authenticate_http(None, false, None));
assert_err!(authenticate_http(&auth, ""));
assert_err!(authenticate_http(&auth, "foobar"));
assert_err!(authenticate_http(&auth, "foo bar"));
Expand Down Expand Up @@ -457,7 +467,7 @@ mod tests {
jwt_key: Some(parse_jwt_key(VALID_JWT_KEY).unwrap()),
..Auth::default()
};
assert_ok!(auth.authenticate_jwt(Some(VALID_JWT), false));
assert_err!(auth.authenticate_jwt(Some(&VALID_JWT[..80]), false));
assert_ok!(auth.authenticate_jwt(Some(VALID_JWT), false, None));
assert_err!(auth.authenticate_jwt(Some(&VALID_JWT[..80]), false, None));
}
}
3 changes: 3 additions & 0 deletions libsql-server/src/connection/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub struct DatabaseConfig {
pub heartbeat_url: Option<Url>,
#[serde(default)]
pub bottomless_db_id: Option<String>,
#[serde(default)]
pub jwt_key: Option<String>,
}

const fn default_max_size() -> u64 {
Expand All @@ -33,6 +35,7 @@ impl Default for DatabaseConfig {
max_db_pages: default_max_size(),
heartbeat_url: None,
bottomless_db_id: None,
jwt_key: None,
}
}
}
11 changes: 8 additions & 3 deletions libsql-server/src/hrana/ws/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,14 @@ async fn handle_hello_msg<F: MakeNamespace>(
jwt: Option<String>,
) -> Result<bool> {
let hello_res = match conn.session.as_mut() {
None => session::handle_initial_hello(&conn.server, conn.version, jwt)
.map(|session| conn.session = Some(session)),
Some(session) => session::handle_repeated_hello(&conn.server, session, jwt),
None => {
session::handle_initial_hello(&conn.server, conn.version, jwt, conn.namespace.clone())
.await
.map(|session| conn.session = Some(session))
}
Some(session) => {
session::handle_repeated_hello(&conn.server, session, jwt, conn.namespace.clone()).await
}
};

match hello_res {
Expand Down
19 changes: 14 additions & 5 deletions libsql-server/src/hrana/ws/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,19 @@ pub enum ResponseError {
Batch(batch::BatchError),
}

pub(super) fn handle_initial_hello<F: MakeNamespace>(
pub(super) async fn handle_initial_hello<F: MakeNamespace>(
server: &Server<F>,
version: Version,
jwt: Option<String>,
namespace: NamespaceName,
) -> Result<Session<<F::Database as Database>::Connection>> {
let namespace_jwt_key = server
.namespaces
.with(namespace, |ns| ns.jwt_key())
.await??;
let authenticated = server
.auth
.authenticate_jwt(jwt.as_deref(), server.disable_namespaces)
.authenticate_jwt(jwt.as_deref(), server.disable_namespaces, namespace_jwt_key)
.map_err(|err| anyhow!(ResponseError::Auth { source: err }))?;

Ok(Session {
Expand All @@ -83,21 +88,25 @@ pub(super) fn handle_initial_hello<F: MakeNamespace>(
})
}

pub(super) fn handle_repeated_hello<F: MakeNamespace>(
pub(super) async fn handle_repeated_hello<F: MakeNamespace>(
server: &Server<F>,
session: &mut Session<<F::Database as Database>::Connection>,
jwt: Option<String>,
namespace: NamespaceName,
) -> Result<()> {
if session.version < Version::Hrana2 {
bail!(ProtocolError::NotSupported {
what: "Repeated hello message",
min_version: Version::Hrana2,
})
}

let namespace_jwt_key = server
.namespaces
.with(namespace, |ns| ns.jwt_key())
.await??;
session.authenticated = server
.auth
.authenticate_jwt(jwt.as_deref(), server.disable_namespaces)
.authenticate_jwt(jwt.as_deref(), server.disable_namespaces, namespace_jwt_key)
.map_err(|err| anyhow!(ResponseError::Auth { source: err }))?;
Ok(())
}
Expand Down
15 changes: 15 additions & 0 deletions libsql-server/src/http/admin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use tokio_util::io::ReaderStream;
use tower_http::trace::DefaultOnResponse;
use url::Url;

use crate::auth::parse_jwt_key;
use crate::database::Database;
use crate::error::LoadDumpError;
use crate::hrana;
Expand Down Expand Up @@ -190,6 +191,7 @@ async fn handle_get_config<M: MakeNamespace, C: Connector>(
block_reason: config.block_reason.clone(),
max_db_size: Some(max_db_size),
heartbeat_url: config.heartbeat_url.clone().map(|u| u.into()),
jwt_key: config.jwt_key.clone(),
};

Ok(Json(resp))
Expand Down Expand Up @@ -232,13 +234,19 @@ struct HttpDatabaseConfig {
max_db_size: Option<bytesize::ByteSize>,
#[serde(default)]
heartbeat_url: Option<String>,
#[serde(default)]
jwt_key: Option<String>,
}

async fn handle_post_config<M: MakeNamespace, C>(
State(app_state): State<Arc<AppState<M, C>>>,
Path(namespace): Path<String>,
Json(req): Json<HttpDatabaseConfig>,
) -> crate::Result<()> {
if let Some(jwt_key) = req.jwt_key.as_deref() {
// Check that the jwt key is correct
parse_jwt_key(jwt_key)?;
}
let store = app_state
.namespaces
.config_store(NamespaceName::from_string(namespace)?)
Expand All @@ -253,6 +261,7 @@ async fn handle_post_config<M: MakeNamespace, C>(
if let Some(url) = req.heartbeat_url {
config.heartbeat_url = Some(Url::parse(&url)?);
}
config.jwt_key = req.jwt_key;

store.store(config).await?;

Expand All @@ -265,13 +274,18 @@ struct CreateNamespaceReq {
max_db_size: Option<bytesize::ByteSize>,
heartbeat_url: Option<String>,
bottomless_db_id: Option<String>,
jwt_key: Option<String>,
}

async fn handle_create_namespace<M: MakeNamespace, C: Connector>(
State(app_state): State<Arc<AppState<M, C>>>,
Path(namespace): Path<String>,
Json(req): Json<CreateNamespaceReq>,
) -> crate::Result<()> {
if let Some(jwt_key) = req.jwt_key.as_deref() {
// Check that the jwt key is correct
parse_jwt_key(jwt_key)?;
}
let dump = match req.dump_url {
Some(ref url) => {
RestoreOption::Dump(dump_stream_from_url(url, app_state.connector.clone()).await?)
Expand All @@ -297,6 +311,7 @@ async fn handle_create_namespace<M: MakeNamespace, C: Connector>(
if let Some(url) = req.heartbeat_url {
config.heartbeat_url = Some(Url::parse(&url)?)
}
config.jwt_key = req.jwt_key;
store.store(config).await?;

Ok(())
Expand Down
14 changes: 11 additions & 3 deletions libsql-server/src/http/user/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,18 @@ where
parts: &mut Parts,
state: &AppState<M>,
) -> Result<Self, Self::Rejection> {
let ns = db_factory::namespace_from_headers(
&parts.headers,
state.disable_default_namespace,
state.disable_namespaces,
)?;
let namespace_jwt_key = state.namespaces.with(ns, |ns| ns.jwt_key()).await??;
let auth_header = parts.headers.get(hyper::header::AUTHORIZATION);
let auth = state
.auth
.authenticate_http(auth_header, state.disable_namespaces)?;
let auth = state.auth.authenticate_http(
auth_header,
state.disable_namespaces,
namespace_jwt_key,
)?;

Ok(auth)
}
Expand Down
10 changes: 9 additions & 1 deletion libsql-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ where
db_config: self.db_config.clone(),
base_path: self.path.clone(),
auth: auth.clone(),
disable_namespaces: self.disable_namespaces,
max_active_namespaces: self.max_active_namespaces,
meta_store_config: self.meta_store_config.take(),
};
Expand Down Expand Up @@ -614,6 +615,7 @@ struct Replica<C> {
db_config: DbConfig,
base_path: Arc<Path>,
auth: Arc<Auth>,
disable_namespaces: bool,
max_active_namespaces: usize,
meta_store_config: Option<MetaStoreConfig>,
}
Expand Down Expand Up @@ -649,7 +651,13 @@ impl<C: Connector> Replica<C> {
)
.await?;
let replication_service = ReplicationLogProxyService::new(channel.clone(), uri.clone());
let proxy_service = ReplicaProxyService::new(channel, uri, self.auth.clone());
let proxy_service = ReplicaProxyService::new(
channel,
uri,
namespaces.clone(),
self.auth.clone(),
self.disable_namespaces,
);

Ok((namespaces, proxy_service, replication_service))
}
Expand Down
12 changes: 12 additions & 0 deletions libsql-server/src/namespace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use tokio_util::io::StreamReader;
use tonic::transport::Channel;
use uuid::Uuid;

use crate::auth::parse_jwt_key;
use crate::auth::Authenticated;
use crate::config::MetaStoreConfig;
use crate::connection::config::DatabaseConfig;
Expand Down Expand Up @@ -808,6 +809,17 @@ impl<T: Database> Namespace<T> {
pub fn config(&self) -> Arc<DatabaseConfig> {
self.db_config_store.get()
}

pub fn jwt_key(&self) -> crate::Result<Option<jsonwebtoken::DecodingKey>> {
let config = self.db_config_store.get();
if let Some(jwt_key) = config.jwt_key.as_deref() {
Ok(Some(
parse_jwt_key(jwt_key).context("Could not parse JWT decoding key")?,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be best if you decoded the key only once, when the namespace is loaded. As per jsonwebtoken Documentation:

All the different kind of keys we can use to decode a JWT. This key can be re-used so make sure you only initialize it once if you can for better performance.

Also, consider wrapping it in an Arc: cloning the key is not cheap.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about it but didn't have a good place to store the key. Let's do this in a follow up as this PR is hanging for some time now.

))
} else {
Ok(None)
}
}
}

pub struct ReplicaNamespaceConfig {
Expand Down
Loading
Loading