diff --git a/refinery/Cargo.toml b/refinery/Cargo.toml index fce329ec..fbf22f5a 100644 --- a/refinery/Cargo.toml +++ b/refinery/Cargo.toml @@ -16,9 +16,9 @@ edition = "2018" default = ["toml"] rusqlite-bundled = ["refinery-core/rusqlite-bundled"] rusqlite = ["refinery-core/rusqlite"] -postgres = ["refinery-core/postgres"] +postgres = ["refinery-core/postgres", "refinery-core/postgres-native-tls", "refinery-core/native-tls"] mysql = ["refinery-core/mysql"] -tokio-postgres = ["refinery-core/tokio-postgres"] +tokio-postgres = ["refinery-core/tokio-postgres", "refinery-core/postgres-native-tls", "refinery-core/native-tls"] mysql_async = ["refinery-core/mysql_async"] tiberius = ["refinery-core/tiberius"] tiberius-config = ["refinery-core/tiberius", "refinery-core/tiberius-config"] diff --git a/refinery_cli/Cargo.toml b/refinery_cli/Cargo.toml index e6bca7d8..4a313acc 100644 --- a/refinery_cli/Cargo.toml +++ b/refinery_cli/Cargo.toml @@ -16,7 +16,7 @@ path = "src/main.rs" [features] default = ["mysql", "postgresql", "sqlite-bundled", "mssql"] -postgresql = ["refinery-core/postgres"] +postgresql = ["refinery-core/postgres", "refinery-core/postgres-native-tls", "refinery-core/native-tls"] mysql = ["refinery-core/mysql"] sqlite = ["refinery-core/rusqlite"] sqlite-bundled = ["sqlite", "refinery-core/rusqlite-bundled"] diff --git a/refinery_core/Cargo.toml b/refinery_core/Cargo.toml index 9cde9c17..9ca8e962 100644 --- a/refinery_core/Cargo.toml +++ b/refinery_core/Cargo.toml @@ -31,6 +31,8 @@ walkdir = "2.3.1" # allow multiple versions of the same dependency if API is similar rusqlite = { version = ">= 0.23, <= 0.33", optional = true } postgres = { version = ">=0.17, <= 0.19", optional = true } +native-tls = { version = "0.2", optional = true } +postgres-native-tls = { version = "0.5", optional = true} tokio-postgres = { version = ">= 0.5, <= 0.7", optional = true } mysql = { version = ">= 21.0.0, <= 25", optional = true, default-features = false, features = ["minimal"] } mysql_async = { version = ">= 0.28, <= 0.35", optional = true, default-features = false, features = ["minimal"] } diff --git a/refinery_core/src/config.rs b/refinery_core/src/config.rs index 7eba833f..1a4f1c65 100644 --- a/refinery_core/src/config.rs +++ b/refinery_core/src/config.rs @@ -3,6 +3,7 @@ use crate::Error; use std::convert::TryFrom; use std::path::PathBuf; use std::str::FromStr; +use std::{borrow::Cow, collections::HashMap}; use url::Url; // refinery config file used by migrate_from_config if migration from a Config struct is preferred instead of using the macros @@ -34,6 +35,7 @@ impl Config { db_user: None, db_pass: None, db_name: None, + use_tls: None, #[cfg(feature = "tiberius-config")] trust_cert: false, }, @@ -139,6 +141,10 @@ impl Config { self.main.db_port.as_deref() } + pub fn use_tls(&self) -> Option { + self.main.use_tls + } + pub fn set_db_user(self, db_user: &str) -> Config { Config { main: Main { @@ -203,13 +209,12 @@ impl TryFrom for Config { } }; + let query_params = url + .query_pairs() + .collect::, Cow<'_, str>>>(); + cfg_if::cfg_if! { if #[cfg(feature = "tiberius-config")] { - use std::{borrow::Cow, collections::HashMap}; - let query_params = url - .query_pairs() - .collect::, Cow<'_, str>>>(); - let trust_cert = query_params. get("trust_cert") .unwrap_or(&Cow::Borrowed("false")) @@ -223,6 +228,20 @@ impl TryFrom for Config { } } + let use_tls = match query_params + .get("sslmode") + .unwrap_or(&Cow::Borrowed("disable")) + { + &Cow::Borrowed("disable") => false, + &Cow::Borrowed("require") => true, + _ => { + return Err(Error::new( + Kind::ConfigError("Invalid sslmode value, please use disable/require".into()), + None, + )) + } + }; + Ok(Self { main: Main { db_type, @@ -238,6 +257,7 @@ impl TryFrom for Config { db_user: Some(url.username().to_string()), db_pass: url.password().map(|r| r.to_string()), db_name: Some(url.path().trim_start_matches('/').to_string()), + use_tls: Some(use_tls), #[cfg(feature = "tiberius-config")] trust_cert, }, @@ -270,6 +290,7 @@ struct Main { db_user: Option, db_pass: Option, db_name: Option, + use_tls: Option, #[cfg(feature = "tiberius-config")] #[serde(default)] trust_cert: bool, @@ -453,6 +474,23 @@ mod tests { ); } + #[test] + fn builds_from_sslmode_str() { + let config = + Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=disable") + .unwrap(); + assert!(config.use_tls().is_some()); + assert!(!config.use_tls().unwrap()); + let config = + Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=require") + .unwrap(); + assert!(config.use_tls().is_some()); + assert!(config.use_tls().unwrap()); + let config = + Config::from_str("postgres://root:1234@localhost:5432/refinery?sslmode=invalidvalue"); + assert!(config.is_err()); + } + #[test] fn builds_db_env_var_failure() { std::env::set_var("DATABASE_URL", "this_is_not_a_url"); diff --git a/refinery_core/src/drivers/config.rs b/refinery_core/src/drivers/config.rs index bc95b572..9fdd0975 100644 --- a/refinery_core/src/drivers/config.rs +++ b/refinery_core/src/drivers/config.rs @@ -88,7 +88,16 @@ macro_rules! with_connection { cfg_if::cfg_if! { if #[cfg(feature = "postgres")] { let path = build_db_url("postgresql", &$config); - let conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?; + + let conn; + if $config.use_tls().is_some() && $config.use_tls().unwrap() { + let connector = native_tls::TlsConnector::new().unwrap(); + let connector = postgres_native_tls::MakeTlsConnector::new(connector); + conn = postgres::Client::connect(path.as_str(), connector).migration_err("could not connect to database", None)?; + } else { + conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?; + } + $op(conn) } else { panic!("tried to migrate from config for a postgresql database, but feature postgres not enabled!"); @@ -129,13 +138,25 @@ macro_rules! with_connection_async { cfg_if::cfg_if! { if #[cfg(feature = "tokio-postgres")] { let path = build_db_url("postgresql", $config); - let (client, connection ) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?; - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("connection error: {}", e); - } - }); - $op(client).await + if $config.use_tls().is_some() && $config.use_tls().unwrap() { + let connector = native_tls::TlsConnector::new().unwrap(); + let connector = postgres_native_tls::MakeTlsConnector::new(connector); + let (client, connection) = tokio_postgres::connect(path.as_str(), connector).await.migration_err("could not connect to database", None)?; + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + $op(client).await + } else { + let (client, connection) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?; + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); + $op(client).await + } } else { panic!("tried to migrate async from config for a postgresql database, but tokio-postgres was not enabled!"); }