Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for TLS in postgres/tokio-postgres using native-tls #353

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions refinery/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ edition = "2018"
default = ["toml"]
rusqlite-bundled = ["refinery-core/rusqlite-bundled"]
rusqlite = ["refinery-core/rusqlite"]
postgres = ["refinery-core/postgres"]
postgres = ["refinery-core/postgres", "refinery-core/postgres-native-tls", "refinery-core/native-tls"]
mysql = ["refinery-core/mysql"]
tokio-postgres = ["refinery-core/tokio-postgres"]
tokio-postgres = ["refinery-core/tokio-postgres", "refinery-core/postgres-native-tls", "refinery-core/native-tls"]
mysql_async = ["refinery-core/mysql_async"]
tiberius = ["refinery-core/tiberius"]
tiberius-config = ["refinery-core/tiberius", "refinery-core/tiberius-config"]
Expand Down
2 changes: 1 addition & 1 deletion refinery_cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ path = "src/main.rs"

[features]
default = ["mysql", "postgresql", "sqlite-bundled", "mssql"]
postgresql = ["refinery-core/postgres"]
postgresql = ["refinery-core/postgres", "refinery-core/postgres-native-tls", "refinery-core/native-tls"]
mysql = ["refinery-core/mysql"]
sqlite = ["refinery-core/rusqlite"]
sqlite-bundled = ["sqlite", "refinery-core/rusqlite-bundled"]
Expand Down
2 changes: 2 additions & 0 deletions refinery_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ walkdir = "2.3.1"
# allow multiple versions of the same dependency if API is similar
rusqlite = { version = ">= 0.23, <= 0.33", optional = true }
postgres = { version = ">=0.17, <= 0.19", optional = true }
native-tls = { version = "0.2", optional = true }
postgres-native-tls = { version = "0.5", optional = true}
tokio-postgres = { version = ">= 0.5, <= 0.7", optional = true }
mysql = { version = ">= 21.0.0, <= 25", optional = true, default-features = false, features = ["minimal"] }
mysql_async = { version = ">= 0.28, <= 0.35", optional = true, default-features = false, features = ["minimal"] }
Expand Down
48 changes: 43 additions & 5 deletions refinery_core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::Error;
use std::convert::TryFrom;
use std::path::PathBuf;
use std::str::FromStr;
use std::{borrow::Cow, collections::HashMap};
use url::Url;

// refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros
Expand Down Expand Up @@ -34,6 +35,7 @@ impl Config {
db_user: None,
db_pass: None,
db_name: None,
use_tls: None,
#[cfg(feature = "tiberius-config")]
trust_cert: false,
},
Expand Down Expand Up @@ -139,6 +141,10 @@ impl Config {
self.main.db_port.as_deref()
}

pub fn use_tls(&self) -> Option<bool> {
self.main.use_tls
}

pub fn set_db_user(self, db_user: &str) -> Config {
Config {
main: Main {
Expand Down Expand Up @@ -203,13 +209,12 @@ impl TryFrom<Url> for Config {
}
};

let query_params = url
.query_pairs()
.collect::<HashMap<Cow<'_, str>, Cow<'_, str>>>();

cfg_if::cfg_if! {
if #[cfg(feature = "tiberius-config")] {
use std::{borrow::Cow, collections::HashMap};
let query_params = url
.query_pairs()
.collect::<HashMap< Cow<'_, str>, Cow<'_, str>>>();

let trust_cert = query_params.
get("trust_cert")
.unwrap_or(&Cow::Borrowed("false"))
Expand All @@ -223,6 +228,20 @@ impl TryFrom<Url> for Config {
}
}

let use_tls = match query_params
.get("sslmode")
.unwrap_or(&Cow::Borrowed("disable"))
{
&Cow::Borrowed("disable") => false,
&Cow::Borrowed("require") => true,
_ => {
return Err(Error::new(
Kind::ConfigError("Invalid sslmode value, please use disable/require".into()),
None,
))
}
};

Ok(Self {
main: Main {
db_type,
Expand All @@ -238,6 +257,7 @@ impl TryFrom<Url> for Config {
db_user: Some(url.username().to_string()),
db_pass: url.password().map(|r| r.to_string()),
db_name: Some(url.path().trim_start_matches('/').to_string()),
use_tls: Some(use_tls),
#[cfg(feature = "tiberius-config")]
trust_cert,
},
Expand Down Expand Up @@ -270,6 +290,7 @@ struct Main {
db_user: Option<String>,
db_pass: Option<String>,
db_name: Option<String>,
use_tls: Option<bool>,
#[cfg(feature = "tiberius-config")]
#[serde(default)]
trust_cert: bool,
Expand Down Expand Up @@ -453,6 +474,23 @@ mod tests {
);
}

#[test]
fn builds_from_sslmode_str() {
let config =
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=disable")
.unwrap();
assert!(config.use_tls().is_some());
assert!(!config.use_tls().unwrap());
let config =
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=require")
.unwrap();
assert!(config.use_tls().is_some());
assert!(config.use_tls().unwrap());
let config =
Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=invalidvalue");
assert!(config.is_err());
}

#[test]
fn builds_db_env_var_failure() {
std::env::set_var("DATABASE_URL", "this_is_not_a_url");
Expand Down
37 changes: 29 additions & 8 deletions refinery_core/src/drivers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,16 @@ macro_rules! with_connection {
cfg_if::cfg_if! {
if #[cfg(feature = "postgres")] {
let path = build_db_url("postgresql", &$config);
let conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;

let conn;
if $config.use_tls().is_some() && $config.use_tls().unwrap() {
let connector = native_tls::TlsConnector::new().unwrap();
let connector = postgres_native_tls::MakeTlsConnector::new(connector);
conn = postgres::Client::connect(path.as_str(), connector).migration_err("could not connect to database", None)?;
} else {
conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;
}

$op(conn)
} else {
panic!("tried to migrate from config for a postgresql database, but feature postgres not enabled!");
Expand Down Expand Up @@ -129,13 +138,25 @@ macro_rules! with_connection_async {
cfg_if::cfg_if! {
if #[cfg(feature = "tokio-postgres")] {
let path = build_db_url("postgresql", $config);
let (client, connection ) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
$op(client).await
if $config.use_tls().is_some() && $config.use_tls().unwrap() {
let connector = native_tls::TlsConnector::new().unwrap();
let connector = postgres_native_tls::MakeTlsConnector::new(connector);
let (client, connection) = tokio_postgres::connect(path.as_str(), connector).await.migration_err("could not connect to database", None)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
$op(client).await
} else {
let (client, connection) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
$op(client).await
}
} else {
panic!("tried to migrate async from config for a postgresql database, but tokio-postgres was not enabled!");
}
Expand Down
Loading