Skip to content

Commit 0d50403

Browse files
authored
Server TLS (#417)
* Server TLS * Finish up TLS * thats it * diff * remove dead code * maybe? * dirty shutdown * skip flakey test * remove unused error * fetch config once
1 parent 4a87b48 commit 0d50403

File tree

10 files changed

+311
-33
lines changed

10 files changed

+311
-33
lines changed

Cargo.lock

+32
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ nix = "0.26.2"
3939
atomic_enum = "0.2.0"
4040
postgres-protocol = "0.6.5"
4141
fallible-iterator = "0.2"
42+
pin-project = "1"
43+
webpki-roots = "0.23"
44+
rustls = { version = "0.21", features = ["dangerous_configuration"] }
4245

4346
[target.'cfg(not(target_env = "msvc"))'.dependencies]
4447
jemallocator = "0.5.0"

pgcat.toml

+8-2
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,15 @@ tcp_keepalives_count = 5
6161
tcp_keepalives_interval = 5
6262

6363
# Path to TLS Certificate file to use for TLS connections
64-
# tls_certificate = "server.cert"
64+
# tls_certificate = ".circleci/server.cert"
6565
# Path to TLS private key file to use for TLS connections
66-
# tls_private_key = "server.key"
66+
# tls_private_key = ".circleci/server.key"
67+
68+
# Enable/disable server TLS
69+
server_tls = false
70+
71+
# Verify server certificate is completely authentic.
72+
verify_server_certificate = false
6773

6874
# User name to access the virtual administrative database (pgbouncer or pgcat)
6975
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..

src/client.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ where
539539
Some(md5_hash_password(username, password, &salt))
540540
} else {
541541
if !get_config().is_auth_query_configured() {
542+
wrong_password(&mut write, username).await?;
542543
return Err(Error::ClientAuthImpossible(username.into()));
543544
}
544545

@@ -565,6 +566,8 @@ where
565566
}
566567

567568
Err(err) => {
569+
wrong_password(&mut write, username).await?;
570+
568571
return Err(Error::ClientAuthPassthroughError(
569572
err.to_string(),
570573
client_identifier,
@@ -587,7 +590,15 @@ where
587590
client_identifier
588591
);
589592

590-
let fetched_hash = refetch_auth_hash(&pool).await?;
593+
let fetched_hash = match refetch_auth_hash(&pool).await {
594+
Ok(fetched_hash) => fetched_hash,
595+
Err(err) => {
596+
wrong_password(&mut write, username).await?;
597+
598+
return Err(err);
599+
}
600+
};
601+
591602
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
592603

593604
// Ok password changed in server an auth is possible.

src/config.rs

+14
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,13 @@ pub struct General {
281281

282282
pub tls_certificate: Option<String>,
283283
pub tls_private_key: Option<String>,
284+
285+
#[serde(default)] // false
286+
pub server_tls: bool,
287+
288+
#[serde(default)] // false
289+
pub verify_server_certificate: bool,
290+
284291
pub admin_username: String,
285292
pub admin_password: String,
286293

@@ -373,6 +380,8 @@ impl Default for General {
373380
autoreload: None,
374381
tls_certificate: None,
375382
tls_private_key: None,
383+
server_tls: false,
384+
verify_server_certificate: false,
376385
admin_username: String::from("admin"),
377386
admin_password: String::from("admin"),
378387
auth_query: None,
@@ -852,6 +861,11 @@ impl Config {
852861
info!("TLS support is disabled");
853862
}
854863
};
864+
info!("Server TLS enabled: {}", self.general.server_tls);
865+
info!(
866+
"Server TLS certificate verification: {}",
867+
self.general.verify_server_certificate
868+
);
855869

856870
for (pool_name, pool_config) in &self.pools {
857871
// TODO: Make this output prettier (maybe a table?)

src/messages.rs

+42-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,10 @@ where
116116

117117
/// Send the startup packet the server. We're pretending we're a Pg client.
118118
/// This tells the server which user we are and what database we want.
119-
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
119+
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
120+
where
121+
S: tokio::io::AsyncWrite + std::marker::Unpin,
122+
{
120123
let mut bytes = BytesMut::with_capacity(25);
121124

122125
bytes.put_i32(196608); // Protocol number
@@ -150,6 +153,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
150153
}
151154
}
152155

156+
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
157+
let mut bytes = BytesMut::with_capacity(12);
158+
159+
bytes.put_i32(8);
160+
bytes.put_i32(80877103);
161+
162+
match stream.write_all(&bytes).await {
163+
Ok(_) => Ok(()),
164+
Err(err) => Err(Error::SocketError(format!(
165+
"Error writing SSLRequest to server socket - Error: {:?}",
166+
err
167+
))),
168+
}
169+
}
170+
153171
/// Parse the params the server sends as a key/value format.
154172
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
155173
let mut result = HashMap::new();
@@ -505,6 +523,29 @@ where
505523
}
506524
}
507525

526+
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
527+
where
528+
S: tokio::io::AsyncWrite + std::marker::Unpin,
529+
{
530+
match stream.write_all(buf).await {
531+
Ok(_) => match stream.flush().await {
532+
Ok(_) => Ok(()),
533+
Err(err) => {
534+
return Err(Error::SocketError(format!(
535+
"Error flushing socket - Error: {:?}",
536+
err
537+
)))
538+
}
539+
},
540+
Err(err) => {
541+
return Err(Error::SocketError(format!(
542+
"Error writing to socket - Error: {:?}",
543+
err
544+
)))
545+
}
546+
}
547+
}
548+
508549
/// Read a complete message from the socket.
509550
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
510551
where

src/pool.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,7 @@ impl ConnectionPool {
376376
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
377377
.test_on_check_out(false)
378378
.build(manager)
379-
.await
380-
.unwrap();
379+
.await?;
381380

382381
pools.push(pool);
383382
servers.push(address);

0 commit comments

Comments
 (0)