Skip to content

Commit c104b23

Browse files
authored
Merge pull request #1151 from sunng87/feature/ssl-negotiation
feat: sslnegotiation and direct ssl for postgres 17
2 parents 07b6878 + 720ffe8 commit c104b23

27 files changed

+203
-56
lines changed

Diff for: docker-compose.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
version: '2'
22
services:
33
postgres:
4-
image: postgres:14
4+
image: docker.io/postgres:17
55
ports:
66
- 5433:5433
77
volumes:

Diff for: postgres-native-tls/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ default = ["runtime"]
1616
runtime = ["tokio-postgres/runtime"]
1717

1818
[dependencies]
19-
native-tls = "0.2"
19+
native-tls = { version = "0.2", features = ["alpn"] }
2020
tokio = "1.0"
2121
tokio-native-tls = "0.3"
2222
tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false }

Diff for: postgres-native-tls/src/lib.rs

+8
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
//! ```
5454
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
5555

56+
use native_tls::TlsConnectorBuilder;
5657
use std::future::Future;
5758
use std::io;
5859
use std::pin::Pin;
@@ -180,3 +181,10 @@ where
180181
}
181182
}
182183
}
184+
185+
/// Set ALPN for `TlsConnectorBuilder`
186+
///
187+
/// This is required when using `sslnegotiation=direct`
188+
pub fn set_postgresql_alpn(builder: &mut TlsConnectorBuilder) {
189+
builder.request_alpns(&["postgresql"]);
190+
}

Diff for: postgres-native-tls/src/test.rs

+16-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use tokio_postgres::tls::TlsConnect;
55

66
#[cfg(feature = "runtime")]
77
use crate::MakeTlsConnector;
8-
use crate::TlsConnector;
8+
use crate::{set_postgresql_alpn, TlsConnector};
99

1010
async fn smoke_test<T>(s: &str, tls: T)
1111
where
@@ -42,6 +42,21 @@ async fn require() {
4242
.await;
4343
}
4444

45+
#[tokio::test]
46+
async fn direct() {
47+
let mut builder = native_tls::TlsConnector::builder();
48+
builder.add_root_certificate(
49+
Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(),
50+
);
51+
set_postgresql_alpn(&mut builder);
52+
let connector = builder.build().unwrap();
53+
smoke_test(
54+
"user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct",
55+
TlsConnector::new(connector, "localhost"),
56+
)
57+
.await;
58+
}
59+
4560
#[tokio::test]
4661
async fn prefer() {
4762
let connector = native_tls::TlsConnector::builder()

Diff for: postgres-openssl/src/lib.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ use openssl::hash::MessageDigest;
5353
use openssl::nid::Nid;
5454
#[cfg(feature = "runtime")]
5555
use openssl::ssl::SslConnector;
56-
use openssl::ssl::{self, ConnectConfiguration, SslRef};
56+
use openssl::ssl::{self, ConnectConfiguration, SslConnectorBuilder, SslRef};
5757
use openssl::x509::X509VerifyResult;
5858
use std::error::Error;
5959
use std::fmt::{self, Debug};
@@ -250,3 +250,10 @@ fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
250250
};
251251
cert.digest(md).ok().map(|b| b.to_vec())
252252
}
253+
254+
/// Set ALPN for `SslConnectorBuilder`
255+
///
256+
/// This is required when using `sslnegotiation=direct`
257+
pub fn set_postgresql_alpn(builder: &mut SslConnectorBuilder) -> Result<(), ErrorStack> {
258+
builder.set_alpn_protos(b"\x0apostgresql")
259+
}

Diff for: postgres-openssl/src/test.rs

+13
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ async fn require() {
3737
.await;
3838
}
3939

40+
#[tokio::test]
41+
async fn direct() {
42+
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
43+
builder.set_ca_file("../test/server.crt").unwrap();
44+
set_postgresql_alpn(&mut builder).unwrap();
45+
let ctx = builder.build();
46+
smoke_test(
47+
"user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct",
48+
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
49+
)
50+
.await;
51+
}
52+
4053
#[tokio::test]
4154
async fn prefer() {
4255
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();

Diff for: postgres-protocol/src/message/backend.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ pub struct ColumnFormats<'a> {
475475
remaining: u16,
476476
}
477477

478-
impl<'a> FallibleIterator for ColumnFormats<'a> {
478+
impl FallibleIterator for ColumnFormats<'_> {
479479
type Item = u16;
480480
type Error = io::Error;
481481

@@ -557,7 +557,7 @@ pub struct DataRowRanges<'a> {
557557
remaining: u16,
558558
}
559559

560-
impl<'a> FallibleIterator for DataRowRanges<'a> {
560+
impl FallibleIterator for DataRowRanges<'_> {
561561
type Item = Option<Range<usize>>;
562562
type Error = io::Error;
563563

@@ -645,7 +645,7 @@ pub struct ErrorField<'a> {
645645
value: &'a [u8],
646646
}
647647

648-
impl<'a> ErrorField<'a> {
648+
impl ErrorField<'_> {
649649
#[inline]
650650
pub fn type_(&self) -> u8 {
651651
self.type_
@@ -717,7 +717,7 @@ pub struct Parameters<'a> {
717717
remaining: u16,
718718
}
719719

720-
impl<'a> FallibleIterator for Parameters<'a> {
720+
impl FallibleIterator for Parameters<'_> {
721721
type Item = Oid;
722722
type Error = io::Error;
723723

Diff for: postgres-protocol/src/types/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ impl<'a> Array<'a> {
582582
/// An iterator over the dimensions of an array.
583583
pub struct ArrayDimensions<'a>(&'a [u8]);
584584

585-
impl<'a> FallibleIterator for ArrayDimensions<'a> {
585+
impl FallibleIterator for ArrayDimensions<'_> {
586586
type Item = ArrayDimension;
587587
type Error = StdBox<dyn Error + Sync + Send>;
588588

@@ -950,7 +950,7 @@ pub struct PathPoints<'a> {
950950
buf: &'a [u8],
951951
}
952952

953-
impl<'a> FallibleIterator for PathPoints<'a> {
953+
impl FallibleIterator for PathPoints<'_> {
954954
type Item = Point;
955955
type Error = StdBox<dyn Error + Sync + Send>;
956956

Diff for: postgres-types/src/lib.rs

+10-10
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ pub enum Format {
914914
Binary,
915915
}
916916

917-
impl<'a, T> ToSql for &'a T
917+
impl<T> ToSql for &T
918918
where
919919
T: ToSql,
920920
{
@@ -963,7 +963,7 @@ impl<T: ToSql> ToSql for Option<T> {
963963
to_sql_checked!();
964964
}
965965

966-
impl<'a, T: ToSql> ToSql for &'a [T] {
966+
impl<T: ToSql> ToSql for &[T] {
967967
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
968968
let member_type = match *ty.kind() {
969969
Kind::Array(ref member) => member,
@@ -1004,7 +1004,7 @@ impl<'a, T: ToSql> ToSql for &'a [T] {
10041004
to_sql_checked!();
10051005
}
10061006

1007-
impl<'a> ToSql for &'a [u8] {
1007+
impl ToSql for &[u8] {
10081008
fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
10091009
types::bytea_to_sql(self, w);
10101010
Ok(IsNull::No)
@@ -1064,7 +1064,7 @@ impl<T: ToSql> ToSql for Box<[T]> {
10641064
to_sql_checked!();
10651065
}
10661066

1067-
impl<'a> ToSql for Cow<'a, [u8]> {
1067+
impl ToSql for Cow<'_, [u8]> {
10681068
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
10691069
<&[u8] as ToSql>::to_sql(&self.as_ref(), ty, w)
10701070
}
@@ -1088,7 +1088,7 @@ impl ToSql for Vec<u8> {
10881088
to_sql_checked!();
10891089
}
10901090

1091-
impl<'a> ToSql for &'a str {
1091+
impl ToSql for &str {
10921092
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
10931093
match ty.name() {
10941094
"ltree" => types::ltree_to_sql(self, w),
@@ -1109,7 +1109,7 @@ impl<'a> ToSql for &'a str {
11091109
to_sql_checked!();
11101110
}
11111111

1112-
impl<'a> ToSql for Cow<'a, str> {
1112+
impl ToSql for Cow<'_, str> {
11131113
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
11141114
<&str as ToSql>::to_sql(&self.as_ref(), ty, w)
11151115
}
@@ -1256,17 +1256,17 @@ impl BorrowToSql for &dyn ToSql {
12561256
}
12571257
}
12581258

1259-
impl<'a> sealed::Sealed for Box<dyn ToSql + Sync + 'a> {}
1259+
impl sealed::Sealed for Box<dyn ToSql + Sync + '_> {}
12601260

1261-
impl<'a> BorrowToSql for Box<dyn ToSql + Sync + 'a> {
1261+
impl BorrowToSql for Box<dyn ToSql + Sync + '_> {
12621262
#[inline]
12631263
fn borrow_to_sql(&self) -> &dyn ToSql {
12641264
self.as_ref()
12651265
}
12661266
}
12671267

1268-
impl<'a> sealed::Sealed for Box<dyn ToSql + Sync + Send + 'a> {}
1269-
impl<'a> BorrowToSql for Box<dyn ToSql + Sync + Send + 'a> {
1268+
impl sealed::Sealed for Box<dyn ToSql + Sync + Send + '_> {}
1269+
impl BorrowToSql for Box<dyn ToSql + Sync + Send + '_> {
12701270
#[inline]
12711271
fn borrow_to_sql(&self) -> &dyn ToSql {
12721272
self.as_ref()

Diff for: postgres/src/config.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use std::time::Duration;
1212
use tokio::runtime;
1313
#[doc(inline)]
1414
pub use tokio_postgres::config::{
15-
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
15+
ChannelBinding, Host, LoadBalanceHosts, SslMode, SslNegotiation, TargetSessionAttrs,
1616
};
1717
use tokio_postgres::error::DbError;
1818
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
@@ -40,6 +40,9 @@ use tokio_postgres::{Error, Socket};
4040
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
4141
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
4242
/// with the `connect` method.
43+
/// * `sslnegotiation` - TLS negotiation method. If set to `direct`, the client will perform direct TLS handshake, this only works for PostgreSQL 17 and newer.
44+
/// Note that you will need to setup ALPN of TLS client configuration to `postgresql` when using direct TLS.
45+
/// If set to `postgres`, the default value, it follows original postgres wire protocol to perform the negotiation.
4346
/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format,
4447
/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses.
4548
/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address,
@@ -230,6 +233,17 @@ impl Config {
230233
self.config.get_ssl_mode()
231234
}
232235

236+
/// Sets the SSL negotiation method
237+
pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
238+
self.config.ssl_negotiation(ssl_negotiation);
239+
self
240+
}
241+
242+
/// Gets the SSL negotiation method
243+
pub fn get_ssl_negotiation(&self) -> SslNegotiation {
244+
self.config.get_ssl_negotiation()
245+
}
246+
233247
/// Adds a host to the configuration.
234248
///
235249
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix

Diff for: postgres/src/notifications.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ pub struct Iter<'a> {
7777
connection: ConnectionRef<'a>,
7878
}
7979

80-
impl<'a> FallibleIterator for Iter<'a> {
80+
impl FallibleIterator for Iter<'_> {
8181
type Item = Notification;
8282
type Error = Error;
8383

@@ -100,7 +100,7 @@ pub struct BlockingIter<'a> {
100100
connection: ConnectionRef<'a>,
101101
}
102102

103-
impl<'a> FallibleIterator for BlockingIter<'a> {
103+
impl FallibleIterator for BlockingIter<'_> {
104104
type Item = Notification;
105105
type Error = Error;
106106

@@ -129,7 +129,7 @@ pub struct TimeoutIter<'a> {
129129
timeout: Duration,
130130
}
131131

132-
impl<'a> FallibleIterator for TimeoutIter<'a> {
132+
impl FallibleIterator for TimeoutIter<'_> {
133133
type Item = Notification;
134134
type Error = Error;
135135

Diff for: postgres/src/transaction.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub struct Transaction<'a> {
1212
transaction: Option<tokio_postgres::Transaction<'a>>,
1313
}
1414

15-
impl<'a> Drop for Transaction<'a> {
15+
impl Drop for Transaction<'_> {
1616
fn drop(&mut self) {
1717
if let Some(transaction) = self.transaction.take() {
1818
let _ = self.connection.block_on(transaction.rollback());

Diff for: tokio-postgres/src/cancel_query.rs

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use crate::client::SocketConfig;
2-
use crate::config::SslMode;
2+
use crate::config::{SslMode, SslNegotiation};
33
use crate::tls::MakeTlsConnect;
44
use crate::{cancel_query_raw, connect_socket, Error, Socket};
55
use std::io;
66

77
pub(crate) async fn cancel_query<T>(
88
config: Option<SocketConfig>,
99
ssl_mode: SslMode,
10+
ssl_negotiation: SslNegotiation,
1011
mut tls: T,
1112
process_id: i32,
1213
secret_key: i32,
@@ -38,6 +39,14 @@ where
3839
)
3940
.await?;
4041

41-
cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, has_hostname, process_id, secret_key)
42-
.await
42+
cancel_query_raw::cancel_query_raw(
43+
socket,
44+
ssl_mode,
45+
ssl_negotiation,
46+
tls,
47+
has_hostname,
48+
process_id,
49+
secret_key,
50+
)
51+
.await
4352
}

Diff for: tokio-postgres/src/cancel_query_raw.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::config::SslMode;
1+
use crate::config::{SslMode, SslNegotiation};
22
use crate::tls::TlsConnect;
33
use crate::{connect_tls, Error};
44
use bytes::BytesMut;
@@ -8,6 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
88
pub async fn cancel_query_raw<S, T>(
99
stream: S,
1010
mode: SslMode,
11+
negotiation: SslNegotiation,
1112
tls: T,
1213
has_hostname: bool,
1314
process_id: i32,
@@ -17,7 +18,7 @@ where
1718
S: AsyncRead + AsyncWrite + Unpin,
1819
T: TlsConnect<S>,
1920
{
20-
let mut stream = connect_tls::connect_tls(stream, mode, tls, has_hostname).await?;
21+
let mut stream = connect_tls::connect_tls(stream, mode, negotiation, tls, has_hostname).await?;
2122

2223
let mut buf = BytesMut::new();
2324
frontend::cancel_request(process_id, secret_key, &mut buf);

Diff for: tokio-postgres/src/cancel_token.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::config::SslMode;
1+
use crate::config::{SslMode, SslNegotiation};
22
use crate::tls::TlsConnect;
33
#[cfg(feature = "runtime")]
44
use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect, Socket};
@@ -12,6 +12,7 @@ pub struct CancelToken {
1212
#[cfg(feature = "runtime")]
1313
pub(crate) socket_config: Option<SocketConfig>,
1414
pub(crate) ssl_mode: SslMode,
15+
pub(crate) ssl_negotiation: SslNegotiation,
1516
pub(crate) process_id: i32,
1617
pub(crate) secret_key: i32,
1718
}
@@ -37,6 +38,7 @@ impl CancelToken {
3738
cancel_query::cancel_query(
3839
self.socket_config.clone(),
3940
self.ssl_mode,
41+
self.ssl_negotiation,
4042
tls,
4143
self.process_id,
4244
self.secret_key,
@@ -54,6 +56,7 @@ impl CancelToken {
5456
cancel_query_raw::cancel_query_raw(
5557
stream,
5658
self.ssl_mode,
59+
self.ssl_negotiation,
5760
tls,
5861
true,
5962
self.process_id,

0 commit comments

Comments
 (0)