-
Notifications
You must be signed in to change notification settings - Fork 220
/
Copy pathserver.rs
710 lines (583 loc) · 27.7 KB
/
server.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
/// Implementation of the PostgreSQL server (database) protocol.
/// Here we are pretending to be a Postgres client.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn};
use std::io::Read;
use std::time::SystemTime;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use crate::config::{Address, User};
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::ClientServerMap;
use crate::scram::ScramSha256;
use crate::stats::Reporter;
/// Server state.
///
/// Holds both halves of the TCP connection to a single Postgres server together with
/// the bookkeeping kept about it: transaction status, pending data, health,
/// cleanup needs and statistics reporting.
pub struct Server {
/// Identifier of this server connection; passed to the stats reporter with each metric.
server_id: i32,
/// Server host, e.g. localhost,
/// port, e.g. 5432, and role, e.g. primary or replica.
address: Address,
/// Buffered read socket.
read: BufReader<OwnedReadHalf>,
/// Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf,
/// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut,
/// Server information the server sent us over on startup
/// (the `ParameterStatus` messages, stored in wire format).
server_info: BytesMut,
/// Backend id and secret key used for query cancellation.
/// Both are read from the `BackendKeyData` message during startup.
process_id: i32,
secret_key: i32,
/// Is the server inside a transaction or idle.
/// Updated from every `ReadyForQuery` message seen in `recv`.
in_transaction: bool,
/// Is there more data for the client to read.
data_available: bool,
/// Is the server broken? We'll remove it from the pool if so.
bad: bool,
/// If server connection requires a DISCARD ALL before checkin
needs_cleanup: bool,
/// Mapping of clients and servers used for query cancellation.
client_server_map: ClientServerMap,
/// Server connected at.
connected_at: chrono::naive::NaiveDateTime,
/// Reports various metrics, e.g. data sent & received.
stats: Reporter,
/// Application name using the server at the moment.
application_name: String,
/// Last time that a successful server send or response happened
last_activity: SystemTime,
}
impl Server {
/// Pretend to be the Postgres client and connect to the server given host, port and credentials.
/// Perform the authentication and return the server in a ready for query state.
pub async fn startup(
server_id: i32,
address: &Address,
user: &User,
database: &str,
client_server_map: ClientServerMap,
stats: Reporter,
) -> Result<Server, Error> {
let mut stream =
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
Ok(stream) => stream,
Err(err) => {
error!("Could not connect to server: {}", err);
return Err(Error::SocketError(format!(
"Could not connect to server: {}",
err
)));
}
};
configure_socket(&stream);
trace!("Sending StartupMessage");
// StartupMessage
startup(&mut stream, &user.username, database).await?;
let mut server_info = BytesMut::new();
let mut process_id: i32 = 0;
let mut secret_key: i32 = 0;
// We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete.
let mut scram = ScramSha256::new(&user.password);
loop {
let code = match stream.read_u8().await {
Ok(code) => code as char,
Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
trace!("Message: {}", code);
match code {
// Authentication
'R' => {
// Determine which kind of authentication is required, if any.
let auth_code = match stream.read_i32().await {
Ok(auth_code) => auth_code,
Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
trace!("Auth: {}", auth_code);
match auth_code {
MD5_ENCRYPTED_PASSWORD => {
// The salt is 4 bytes.
// See: https://www.postgresql.org/docs/12/protocol-message-formats.html
let mut salt = vec![0u8; 4];
match stream.read_exact(&mut salt).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
md5_password(&mut stream, &user.username, &user.password, &salt[..])
.await?;
}
AUTHENTICATION_SUCCESSFUL => (),
SASL => {
debug!("Starting SASL authentication");
let sasl_len = (len - 8) as usize;
let mut sasl_auth = vec![0u8; sasl_len];
match stream.read_exact(&mut sasl_auth).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
let sasl_types: Vec<_> = sasl_auth[..sasl_len - 2]
.split(|&b| b == 0)
.map(|v| String::from_utf8_lossy(v).to_string())
.collect();
if sasl_types.contains(&SCRAM_SHA_256.to_string()) {
debug!("Using {}", SCRAM_SHA_256);
// Generate client message.
let sasl_response = scram.message();
// SASLInitialResponse (F)
let mut res = BytesMut::new();
res.put_u8(b'p');
// length + String length + length + length of sasl response
res.put_i32(
4 // i32 size
+ SCRAM_SHA_256.len() as i32 // length of SASL version string,
+ 1 // Null terminator for the SASL version string,
+ 4 // i32 size
+ sasl_response.len() as i32, // length of SASL response
);
res.put_slice(format!("{}\0", SCRAM_SHA_256).as_bytes());
res.put_i32(sasl_response.len() as i32);
res.put(sasl_response);
write_all(&mut stream, res).await?;
} else {
error!("Unsupported SCRAM version: {:?}", sasl_types);
return Err(Error::ServerError);
}
}
SASL_CONTINUE => {
trace!("Continuing SASL");
let mut sasl_data = vec![0u8; (len - 8) as usize];
match stream.read_exact(&mut sasl_data).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
let msg = BytesMut::from(&sasl_data[..]);
let sasl_response = scram.update(&msg)?;
// SASLResponse
let mut res = BytesMut::new();
res.put_u8(b'p');
res.put_i32(4 + sasl_response.len() as i32);
res.put(sasl_response);
write_all(&mut stream, res).await?;
}
SASL_FINAL => {
trace!("Final SASL");
let mut sasl_final = vec![0u8; len as usize - 8];
match stream.read_exact(&mut sasl_final).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
match scram.finish(&BytesMut::from(&sasl_final[..])) {
Ok(_) => {
debug!("SASL authentication successful");
}
Err(err) => {
debug!("SASL authentication failed");
return Err(err);
}
};
}
_ => {
error!("Unsupported authentication mechanism: {}", auth_code);
return Err(Error::ServerError);
}
}
}
// ErrorResponse
'E' => {
let error_code = match stream.read_u8().await {
Ok(error_code) => error_code,
Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
trace!("Error: {}", error_code);
match error_code {
// No error message is present in the message.
MESSAGE_TERMINATOR => (),
// An error message will be present.
_ => {
// Read the error message without the terminating null character.
let mut error = vec![0u8; len as usize - 4 - 1];
match stream.read_exact(&mut error).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
// TODO: the error message contains multiple fields; we can decode them and
// present a prettier message to the user.
// See: https://www.postgresql.org/docs/12/protocol-error-fields.html
error!("Server error: {}", String::from_utf8_lossy(&error));
}
};
return Err(Error::ServerError);
}
// ParameterStatus
'S' => {
let mut param = vec![0u8; len as usize - 4];
match stream.read_exact(&mut param).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
// Save the parameter so we can pass it to the client later.
// These can be server_encoding, client_encoding, server timezone, Postgres version,
// and many more interesting things we should know about the Postgres server we are talking to.
server_info.put_u8(b'S');
server_info.put_i32(len);
server_info.put_slice(¶m[..]);
}
// BackendKeyData
'K' => {
// The frontend must save these values if it wishes to be able to issue CancelRequest messages later.
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
secret_key = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
}
// ReadyForQuery
'Z' => {
let mut idle = vec![0u8; len as usize - 4];
match stream.read_exact(&mut idle).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
let (read, write) = stream.into_split();
let mut server = Server {
address: address.clone(),
read: BufReader::new(read),
write,
buffer: BytesMut::with_capacity(8196),
server_info,
server_id,
process_id,
secret_key,
in_transaction: false,
data_available: false,
bad: false,
needs_cleanup: false,
client_server_map,
connected_at: chrono::offset::Utc::now().naive_utc(),
stats,
application_name: String::new(),
last_activity: SystemTime::now(),
};
server.set_name("pgcat").await?;
return Ok(server);
}
// We have an unexpected message from the server during this exchange.
// Means we implemented the protocol wrong or we're not talking to a Postgres server.
_ => {
error!("Unknown code: {}", code);
return Err(Error::ProtocolSyncError(format!(
"Unknown server code: {}",
code
)));
}
};
}
}
/// Issue a query cancellation request to the server.
/// Uses a separate connection that's not part of the connection pool.
pub async fn cancel(
host: &str,
port: u16,
process_id: i32,
secret_key: i32,
) -> Result<(), Error> {
let mut stream = match TcpStream::connect(&format!("{}:{}", host, port)).await {
Ok(stream) => stream,
Err(err) => {
error!("Could not connect to server: {}", err);
return Err(Error::SocketError(format!("Error reading cancel message")));
}
};
configure_socket(&stream);
debug!("Sending CancelRequest");
let mut bytes = BytesMut::with_capacity(16);
bytes.put_i32(16);
bytes.put_i32(CANCEL_REQUEST_CODE);
bytes.put_i32(process_id);
bytes.put_i32(secret_key);
write_all(&mut stream, bytes).await
}
/// Forward the client's messages to the server.
///
/// The byte count is reported to the stats collector before the write is attempted.
/// A successful write refreshes `last_activity`; a failed write marks the server
/// as bad and hands the error back to the caller.
pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> {
    self.stats.data_sent(messages.len(), self.server_id);

    if let Err(err) = write_all_half(&mut self.write, messages).await {
        error!("Terminating server because of: {:?}", err);
        self.bad = true;
        return Err(err);
    }

    // Successfully sent to server
    self.last_activity = SystemTime::now();
    Ok(())
}
/// Read the server's reply to a client request and return it as one buffer.
///
/// Messages are accumulated until a point where it makes sense to hand them to the
/// client: `ReadyForQuery`, the start of a COPY, or the buffer reaching its size limit.
/// Callers must keep calling this method while `self.is_data_available()` is true
/// in order to receive everything the server has to offer.
///
/// A read failure or an unknown transaction state marks the server as bad and is
/// returned as an error.
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
    loop {
        let mut message = match read_message(&mut self.read).await {
            Ok(packet) => packet,
            Err(err) => {
                error!("Terminating server because of: {:?}", err);
                self.bad = true;
                return Err(err);
            }
        };

        // Keep a verbatim copy of the message; it is forwarded to the client later.
        self.buffer.put(&message[..]);

        let code = message.get_u8() as char;
        let _len = message.get_i32();

        trace!("Message: {}", code);

        match code {
            // ReadyForQuery
            'Z' => {
                let transaction_state = message.get_u8() as char;

                self.in_transaction = match transaction_state {
                    // 'T': in transaction.
                    // 'E': some error occurred, the transaction was rolled back.
                    'T' | 'E' => true,

                    // Idle, transaction over.
                    'I' => false,

                    // Something totally unexpected, this is not a Postgres server we know.
                    _ => {
                        self.bad = true;
                        return Err(Error::ProtocolSyncError(format!(
                            "Unknown transaction state: {}",
                            transaction_state
                        )));
                    }
                };

                // There is no more data available from the server.
                self.data_available = false;

                break;
            }

            // CommandComplete
            'C' => {
                let mut command_tag = String::new();

                if let Err(err) = message.reader().read_to_string(&mut command_tag) {
                    warn!("Encountered an error while parsing CommandTag {}", err);
                } else if command_tag == "PREPARE\0"
                    || (command_tag == "SET\0" && !self.in_transaction)
                {
                    // Non-exhaustive list of commands that are likely to change session
                    // variables/resources which can leak between clients. This is a best
                    // effort to block bad clients from poisoning a transaction-mode pool
                    // by setting inappropriate session variables.
                    //
                    // SET inside a transaction is deliberately ignored: there is no great
                    // way to tell SET from SET LOCAL, so some cases are missed in exchange
                    // for sending fewer discard statements.
                    debug!("Server connection marked for clean up");
                    self.needs_cleanup = true;
                }
            }

            // DataRow
            'D' => {
                // More data is available after this message, this is not the end of the reply.
                self.data_available = true;

                // Don't flush yet, the more we buffer, the faster this goes...up to a limit.
                let limit_reached = self.buffer.len() >= 8196;
                if limit_reached {
                    break;
                }
            }

            // CopyInResponse: copy is starting from client to server.
            'G' => break,

            // CopyOutResponse: copy is starting from the server to the client.
            'H' => {
                self.data_available = true;
                break;
            }

            // CopyData
            'd' => {
                // Don't flush yet, buffer until we reach limit
                let limit_reached = self.buffer.len() >= 8196;
                if limit_reached {
                    break;
                }
            }

            // CopyDone and anything else (errors, notices, etc.):
            // keep buffering until ReadyForQuery shows up.
            _ => (),
        };
    }

    let bytes = self.buffer.clone();

    // Keep track of how much data we got from the server for stats.
    self.stats.data_received(bytes.len(), self.server_id);

    // Clear the buffer for next query.
    self.buffer.clear();

    // Successfully received data from server
    self.last_activity = SystemTime::now();

    // Pass the data back to the client.
    Ok(bytes)
}
/// Report whether the server is still inside a transaction.
/// If the client disconnects while the server is in a transaction, we will clean it up.
///
/// The current state is also written to the debug log on every call.
pub fn in_transaction(&self) -> bool {
    let state = self.in_transaction;
    debug!("Server in transaction: {}", state);
    state
}
/// Whether the server still has data for the client to read.
///
/// We don't buffer all of server responses, e.g. COPY OUT produces too much data.
/// The client is responsible to call `self.recv()` while this method returns true.
pub fn is_data_available(&self) -> bool {
self.data_available
}
/// Whether this server connection has been flagged as bad.
///
/// Server & client are out of sync, we must discard this connection.
/// This happens with clients that misbehave. The flag is also set when sending to or
/// reading from the server fails.
pub fn is_bad(&self) -> bool {
self.bad
}
/// Get server startup information to forward it to the client.
///
/// Returns a copy of the `ParameterStatus` messages that were collected while the
/// connection was being established.
/// Not used at the moment (per the original note — verify against callers).
pub fn server_info(&self) -> BytesMut {
self.server_info.clone()
}
/// Indicate that this server connection cannot be re-used and must be discarded.
///
/// Logs the server address at error level and sets the `bad` flag.
pub fn mark_bad(&mut self) {
error!("Server {:?} marked bad", self.address);
self.bad = true;
}
/// Claim this server as mine for the purposes of query cancellation.
///
/// Records, under the client's `(process_id, secret_key)` pair, the backend process id,
/// secret key, host and port of this server in the shared client/server map.
pub fn claim(&mut self, process_id: i32, secret_key: i32) {
    let client_key = (process_id, secret_key);
    let server_target = (
        self.process_id,
        self.secret_key,
        self.address.host.clone(),
        self.address.port,
    );

    self.client_server_map
        .lock()
        .insert(client_key, server_target);
}
/// Run an arbitrary query against the server using the simple query protocol.
///
/// The server's reply is read to completion and thrown away, so this is only useful
/// for statements whose result does not matter, such as `SET` or `ROLLBACK`.
pub async fn query(&mut self, query: &str) -> Result<(), Error> {
    let message = simple_query(query);
    self.send(&message).await?;

    // Always read at least one response, then keep draining while more is pending.
    self.recv().await?;
    while self.data_available {
        self.recv().await?;
    }

    Ok(())
}
/// Bring the server connection back to a clean state before it is returned to the pool.
///
/// Two things are undone here: an open transaction is rolled back, and session state
/// flagged via `needs_cleanup` is dropped with `DISCARD ALL`.
pub async fn checkin_cleanup(&mut self) -> Result<(), Error> {
    // Client disconnected with an open transaction on the server connection.
    // Pgbouncer behavior is to close the server connection but that can cause
    // server connection thrashing if clients repeatedly do this.
    // Instead, we ROLLBACK that transaction before putting the connection back in the pool
    let open_transaction = self.in_transaction();
    if open_transaction {
        warn!("Server returned while still in transaction, rolling back transaction");
        self.query("ROLLBACK").await?;
    }

    // For performance reasons `DISCARD ALL` is only sent when we think the session
    // was altered, rather than before every checkin.
    if !self.needs_cleanup {
        return Ok(());
    }

    // Client disconnected but it performed session-altering operations such as
    // SET statement_timeout to 1 or create a prepared statement. We clear that
    // to avoid leaking state between clients.
    warn!("Server returned with session state altered, discarding state");
    self.query("DISCARD ALL").await?;
    self.needs_cleanup = false;

    Ok(())
}
/// A shorthand for `SET application_name = $1`.
///
/// Does nothing when `name` is already the application name recorded for this server.
/// Otherwise the `SET` statement is sent with `name` embedded as a quoted SQL literal
/// (single quotes are doubled so the value cannot terminate the literal early).
///
/// The `needs_cleanup` flag is restored to its previous value afterwards, whether or
/// not the query succeeded, and the new name is only remembered on success.
///
/// # Errors
///
/// Returns whatever error `query` reports.
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
    if self.application_name == name {
        return Ok(());
    }

    // We don't want `SET application_name` to mark the server connection
    // as needing cleanup
    let needs_cleanup_before = self.needs_cleanup;

    // The name is interpolated into a string literal, so escape embedded quotes.
    // NOTE(review): this assumes `standard_conforming_strings` is on (backslashes are
    // not treated as escapes) — confirm for the servers being pooled.
    let escaped_name = name.replace('\'', "''");

    let result = self
        .query(&format!("SET application_name = '{}'", escaped_name))
        .await;

    self.needs_cleanup = needs_cleanup_before;

    if result.is_ok() {
        self.application_name = name.to_string();
    }

    result
}
/// Get the server's address.
///
/// Returns a clone of the stored `Address`.
// NOTE(review): `dead_code` is allowed here, presumably because nothing calls this
// yet — confirm and drop the attribute if it has callers.
#[allow(dead_code)]
pub fn address(&self) -> Address {
self.address.clone()
}
/// Get the server connection identifier.
/// Used to uniquely identify the connection in statistics.
pub fn server_id(&self) -> i32 {
self.server_id
}
/// Get the time of the last successful send to, or response from, the server.
pub fn last_activity(&self) -> SystemTime {
self.last_activity
}
/// Mark the connection as needing `DISCARD ALL` at checkin.
pub fn mark_dirty(&mut self) {
self.needs_cleanup = true;
}
}
impl Drop for Server {
    /// Try to do a clean shut down. Best effort because
    /// the socket is in non-blocking mode, so it may not be ready
    /// for a write.
    ///
    /// Reports the disconnect to the stats collector, attempts to send a `Terminate`
    /// message without waiting, and logs how long the session lasted.
    fn drop(&mut self) {
        // Terminate message: 'X' followed by the length (4) as a big-endian i32.
        // A constant avoids allocating (and then growing) a buffer inside `drop`.
        const TERMINATE: [u8; 5] = [b'X', 0, 0, 0, 4];

        self.stats.server_disconnecting(self.server_id);

        if self.write.try_write(&TERMINATE).is_err() {
            debug!("Dirty shutdown");
        }

        // Should not matter.
        self.bad = true;

        let now = chrono::offset::Utc::now().naive_utc();
        let duration = now - self.connected_at;

        info!(
            "Server connection closed {:?}, session duration: {}",
            self.address,
            crate::format_duration(&duration)
        );
    }
}