From d7bf2de9bff40c7e65a5efdb4d86c5ee77d92199 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 14 Oct 2022 11:10:17 -0400 Subject: [PATCH 01/35] initial commit --- src/admin.rs | 18 ++++++++-------- src/client.rs | 56 ++++++++++++++++++++++--------------------------- src/messages.rs | 10 ++++----- src/server.rs | 25 +++++++++------------- 4 files changed, 49 insertions(+), 60 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index 42af315e..d95fd46b 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -168,7 +168,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show PgCat version. @@ -186,7 +186,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show utilization of connection pools for each shard and replicas. @@ -247,7 +247,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show shards and replicas. @@ -314,7 +314,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Ignore any SET commands the client sends. @@ -346,7 +346,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Shows current configuration. @@ -392,7 +392,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show shard and replicas statistics. @@ -452,7 +452,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show currently connected clients @@ -502,7 +502,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show currently connected servers @@ -556,5 +556,5 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } diff --git a/src/client.rs b/src/client.rs index e72dbf79..5798359f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -37,10 +37,6 @@ pub struct Client { /// better than a stock buffer. write: T, - /// Internal buffer, where we place messages until we have to flush - /// them to the backend. - buffer: BytesMut, - /// Address addr: std::net::SocketAddr, @@ -491,7 +487,6 @@ where read: BufReader::new(read), write: write, addr, - buffer: BytesMut::with_capacity(8196), cancel_mode: false, transaction_mode, process_id, @@ -525,7 +520,6 @@ where read: BufReader::new(read), write: write, addr, - buffer: BytesMut::with_capacity(8196), cancel_mode: true, transaction_mode: false, process_id, @@ -586,6 +580,10 @@ where self.application_name.clone(), ); + // Internal buffer, where we place messages until we have to flush + // them to the backend. + let mut message_buffer = BytesMut::with_capacity(8196); + // Our custom protocol loop. // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocol. @@ -628,7 +626,7 @@ where // to the client so we buffer them and defer the decision to error out or not // to when we get the S message 'P' | 'B' | 'D' | 'E' => { - self.buffer.put(&message[..]); + message_buffer.put(&message[..]); continue; } 'X' => { @@ -754,7 +752,7 @@ where // protocol buffer if message[0] as char == 'S' { error!("Got Sync message but failed to get a connection from the pool"); - self.buffer.clear(); + message_buffer.clear(); } error_response(&mut self.write, "could not get connection from the pool") .await?; @@ -836,7 +834,7 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop(code, message, server, &address, &pool) + self.send_and_receive_loop(code, &message_buffer, server, &address, &pool) .await?; if !server.in_transaction() { @@ -862,25 +860,25 @@ where // Parse // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. 'P' => { - self.buffer.put(&message[..]); + message_buffer.put(&message[..]); } // Bind // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' 'B' => { - self.buffer.put(&message[..]); + message_buffer.put(&message[..]); } // Describe // Command a client can issue to describe a previously prepared named statement. 'D' => { - self.buffer.put(&message[..]); + message_buffer.put(&message[..]); } // Execute // Execute a prepared statement prepared in `P` and bound in `B`. 'E' => { - self.buffer.put(&message[..]); + message_buffer.put(&message[..]); } // Sync @@ -888,9 +886,9 @@ where 'S' => { debug!("Sending query to server"); - self.buffer.put(&message[..]); + message_buffer.put(&message[..]); - let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = (*message_buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' { @@ -898,7 +896,7 @@ where // P followed by 32 int followed by null-terminated statement name // So message code should be in offset 0 of the buffer, first character // in prepared statement name would be index 5 - let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); + let first_char_in_name = *message_buffer.get(5).unwrap_or(&0); if first_char_in_name != 0 { // This is a named prepared statement // Server connection state will need to be cleared at checkin @@ -906,16 +904,10 @@ where } } - self.send_and_receive_loop( - code, - self.buffer.clone(), - server, - &address, - &pool, - ) - .await?; + self.send_and_receive_loop(code, &message_buffer, server, &address, &pool) + .await?; - self.buffer.clear(); + message_buffer.clear(); if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -932,19 +924,19 @@ where 'd' => { // Forward the data to the server, // don't buffer it since it can be rather large. - self.send_server_message(server, message, &address, &pool) + self.send_server_message(server, &message, &address, &pool) .await?; } // CopyDone or CopyFail // Copy is done, successfully or not. 'c' | 'f' => { - self.send_server_message(server, message, &address, &pool) + self.send_server_message(server, &message, &address, &pool) .await?; let response = self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, response).await { + match write_all_half(&mut self.write, &response).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -988,10 +980,12 @@ where guard.remove(&(self.process_id, self.secret_key)); } + /// Message is optional, if specified it will be sent to the server + /// otherwise the internal self.buffer will be used as the message async fn send_and_receive_loop( &mut self, code: char, - message: BytesMut, + message: &BytesMut, server: &mut Server, address: &Address, pool: &ConnectionPool, @@ -1007,7 +1001,7 @@ where loop { let response = self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, response).await { + match write_all_half(&mut self.write, &response).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1033,7 +1027,7 @@ where async fn send_server_message( &self, server: &mut Server, - message: BytesMut, + message: &BytesMut, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { diff --git a/src/messages.rs b/src/messages.rs index 78cb9dbf..8fffcfb5 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -254,7 +254,7 @@ where res.put_i32(len); res.put_slice(&set_complete[..]); - write_all_half(stream, res).await?; + write_all_half(stream, &res).await?; ready_for_query(stream).await } @@ -304,7 +304,7 @@ where res.put_i32(error.len() as i32 + 4); res.put(error); - Ok(write_all_half(stream, res).await?) + Ok(write_all_half(stream, &res).await?) } pub async fn wrong_password(stream: &mut S, user: &str) -> Result<(), Error> @@ -366,7 +366,7 @@ where // CommandComplete res.put(command_complete("SELECT 1")); - write_all_half(stream, res).await?; + write_all_half(stream, &res).await?; ready_for_query(stream).await } @@ -455,11 +455,11 @@ where } /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). -pub async fn write_all_half(stream: &mut S, buf: BytesMut) -> Result<(), Error> +pub async fn write_all_half(stream: &mut S, buf: &BytesMut) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { - match stream.write_all(&buf).await { + match stream.write_all(buf).await { Ok(_) => Ok(()), Err(_) => return Err(Error::SocketError), } diff --git a/src/server.rs b/src/server.rs index d191eb74..899c5004 100644 --- a/src/server.rs +++ b/src/server.rs @@ -32,9 +32,6 @@ pub struct Server { /// Unbuffered write socket (our client code buffers). write: OwnedWriteHalf, - /// Our server response buffer. We buffer data before we give it to the client. - buffer: BytesMut, - /// Server information the server sent us over on startup. server_info: BytesMut, @@ -316,7 +313,6 @@ impl Server { address: address.clone(), read: BufReader::new(read), write: write, - buffer: BytesMut::with_capacity(8196), server_info: server_info, server_id: server_id, process_id: process_id, @@ -375,7 +371,7 @@ impl Server { } /// Send messages to the server from the client. - pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> { + pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> { self.stats.data_sent(messages.len(), self.server_id); match write_all_half(&mut self.write, messages).await { @@ -396,6 +392,9 @@ impl Server { /// This method must be called multiple times while `self.is_data_available()` is true /// in order to receive all data the server has to offer. pub async fn recv(&mut self) -> Result { + // Our server response buffer. We buffer data before we give it to the client. + let mut message_buffer = BytesMut::with_capacity(8196); + loop { let mut message = match read_message(&mut self.read).await { Ok(message) => message, @@ -407,7 +406,7 @@ impl Server { }; // Buffer the message we'll forward to the client later. - self.buffer.put(&message[..]); + message_buffer.put(&message[..]); let code = message.get_u8() as char; let _len = message.get_i32(); @@ -487,7 +486,7 @@ impl Server { self.data_available = true; // Don't flush yet, the more we buffer, the faster this goes...up to a limit. - if self.buffer.len() >= 8196 { + if message_buffer.len() >= 8196 { break; } } @@ -515,19 +514,15 @@ impl Server { }; } - let bytes = self.buffer.clone(); - // Keep track of how much data we got from the server for stats. - self.stats.data_received(bytes.len(), self.server_id); - - // Clear the buffer for next query. - self.buffer.clear(); + self.stats + .data_received(message_buffer.len(), self.server_id); // Successfully received data from server self.last_activity = SystemTime::now(); // Pass the data back to the client. - Ok(bytes) + Ok(message_buffer) } /// If the server is still inside a transaction. @@ -580,7 +575,7 @@ impl Server { pub async fn query(&mut self, query: &str) -> Result<(), Error> { let query = simple_query(query); - self.send(query).await?; + self.send(&query).await?; loop { let _ = self.recv().await?; From b42c33ccfb06661a45313319202a8c7a5983cf55 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 14 Oct 2022 13:20:39 -0400 Subject: [PATCH 02/35] fix typo --- src/client.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/client.rs b/src/client.rs index 5798359f..b22c9f62 100644 --- a/src/client.rs +++ b/src/client.rs @@ -758,7 +758,7 @@ where .await?; error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}", - self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err); + self.pool_name, self.username, query_router.shard(), query_router.role(), err); continue; } }; @@ -834,7 +834,7 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop(code, &message_buffer, server, &address, &pool) + self.send_and_receive_loop(code, &message, server, &address, &pool) .await?; if !server.in_transaction() { @@ -980,8 +980,6 @@ where guard.remove(&(self.process_id, self.secret_key)); } - /// Message is optional, if specified it will be sent to the server - /// otherwise the internal self.buffer will be used as the message async fn send_and_receive_loop( &mut self, code: char, From 3ae2953fa58d3b2a6b423ee2cdd5faf92917deb1 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 14 Oct 2022 14:28:18 -0400 Subject: [PATCH 03/35] use cursor for parse params instead of bytesmut --- src/client.rs | 2 +- src/messages.rs | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/client.rs b/src/client.rs index b22c9f62..fcb38fac 100644 --- a/src/client.rs +++ b/src/client.rs @@ -353,7 +353,7 @@ where ) -> Result, Error> { let config = get_config(); let stats = get_reporter(); - let parameters = parse_startup(bytes.clone())?; + let parameters = parse_startup(&bytes)?; // These two parameters are mandatory by the protocol. let pool_name = match parameters.get("database") { diff --git a/src/messages.rs b/src/messages.rs index 8fffcfb5..754a728a 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -7,6 +7,7 @@ use tokio::net::TcpStream; use crate::errors::Error; use std::collections::HashMap; +use std::io::Cursor; use std::mem; /// Postgres data type mappings @@ -141,18 +142,20 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu } /// Parse the params the server sends as a key/value format. -pub fn parse_params(mut bytes: BytesMut) -> Result, Error> { +pub fn parse_params(bytes: &BytesMut) -> Result, Error> { let mut result = HashMap::new(); let mut buf = Vec::new(); let mut tmp = String::new(); + let mut cursor = Cursor::new(bytes); + while bytes.has_remaining() { - let mut c = bytes.get_u8(); + let mut c = cursor.get_u8(); // Null-terminated C-strings. while c != 0 { tmp.push(c as char); - c = bytes.get_u8(); + c = cursor.get_u8(); } if tmp.len() > 0 { @@ -180,7 +183,7 @@ pub fn parse_params(mut bytes: BytesMut) -> Result, Erro /// Parse StartupMessage parameters. /// e.g. user, database, application_name, etc. -pub fn parse_startup(bytes: BytesMut) -> Result, Error> { +pub fn parse_startup(bytes: &BytesMut) -> Result, Error> { let result = parse_params(bytes)?; // Minimum required parameters From 33a0cad939769ad2769ed21a3f69f2611f170ba7 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 14 Oct 2022 14:30:53 -0400 Subject: [PATCH 04/35] undo --- src/client.rs | 2 +- src/messages.rs | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/client.rs b/src/client.rs index fcb38fac..23ecea79 100644 --- a/src/client.rs +++ b/src/client.rs @@ -353,7 +353,7 @@ where ) -> Result, Error> { let config = get_config(); let stats = get_reporter(); - let parameters = parse_startup(&bytes)?; + let parameters = parse_startup(bytes)?; // These two parameters are mandatory by the protocol. let pool_name = match parameters.get("database") { diff --git a/src/messages.rs b/src/messages.rs index 754a728a..8fffcfb5 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -7,7 +7,6 @@ use tokio::net::TcpStream; use crate::errors::Error; use std::collections::HashMap; -use std::io::Cursor; use std::mem; /// Postgres data type mappings @@ -142,20 +141,18 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu } /// Parse the params the server sends as a key/value format. -pub fn parse_params(bytes: &BytesMut) -> Result, Error> { +pub fn parse_params(mut bytes: BytesMut) -> Result, Error> { let mut result = HashMap::new(); let mut buf = Vec::new(); let mut tmp = String::new(); - let mut cursor = Cursor::new(bytes); - while bytes.has_remaining() { - let mut c = cursor.get_u8(); + let mut c = bytes.get_u8(); // Null-terminated C-strings. while c != 0 { tmp.push(c as char); - c = cursor.get_u8(); + c = bytes.get_u8(); } if tmp.len() > 0 { @@ -183,7 +180,7 @@ pub fn parse_params(bytes: &BytesMut) -> Result, Error> /// Parse StartupMessage parameters. /// e.g. user, database, application_name, etc. -pub fn parse_startup(bytes: &BytesMut) -> Result, Error> { +pub fn parse_startup(bytes: BytesMut) -> Result, Error> { let result = parse_params(bytes)?; // Minimum required parameters From 65bd10d6ff568dc7ebd446590671175e15293ea9 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 17 Oct 2022 10:52:01 -0700 Subject: [PATCH 05/35] Update to use a dedicated server message buffer --- src/client.rs | 52 +++++++++++++++++++++++++++++---------------------- src/server.rs | 17 ++++++++--------- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/src/client.rs b/src/client.rs index 23ecea79..7eb7de22 100644 --- a/src/client.rs +++ b/src/client.rs @@ -582,7 +582,9 @@ where // Internal buffer, where we place messages until we have to flush // them to the backend. - let mut message_buffer = BytesMut::with_capacity(8196); + let mut client_message_buffer = BytesMut::with_capacity(8196); + + let mut server_message_buffer = BytesMut::with_capacity(8196); // Our custom protocol loop. // We expect the client to either start a transaction with regular queries @@ -626,7 +628,7 @@ where // to the client so we buffer them and defer the decision to error out or not // to when we get the S message 'P' | 'B' | 'D' | 'E' => { - message_buffer.put(&message[..]); + client_message_buffer.put(&message[..]); continue; } 'X' => { @@ -752,7 +754,7 @@ where // protocol buffer if message[0] as char == 'S' { error!("Got Sync message but failed to get a connection from the pool"); - message_buffer.clear(); + client_message_buffer.clear(); } error_response(&mut self.write, "could not get connection from the pool") .await?; @@ -834,7 +836,7 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop(code, &message, server, &address, &pool) + self.send_and_receive_loop(code, &message, server, &address, &pool, &mut server_message_buffer) .await?; if !server.in_transaction() { @@ -860,25 +862,25 @@ where // Parse // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. 'P' => { - message_buffer.put(&message[..]); + client_message_buffer.put(&message[..]); } // Bind // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' 'B' => { - message_buffer.put(&message[..]); + client_message_buffer.put(&message[..]); } // Describe // Command a client can issue to describe a previously prepared named statement. 'D' => { - message_buffer.put(&message[..]); + client_message_buffer.put(&message[..]); } // Execute // Execute a prepared statement prepared in `P` and bound in `B`. 'E' => { - message_buffer.put(&message[..]); + client_message_buffer.put(&message[..]); } // Sync @@ -886,9 +888,9 @@ where 'S' => { debug!("Sending query to server"); - message_buffer.put(&message[..]); + client_message_buffer.put(&message[..]); - let first_message_code = (*message_buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = (*client_message_buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' { @@ -896,7 +898,7 @@ where // P followed by 32 int followed by null-terminated statement name // So message code should be in offset 0 of the buffer, first character // in prepared statement name would be index 5 - let first_char_in_name = *message_buffer.get(5).unwrap_or(&0); + let first_char_in_name = *client_message_buffer.get(5).unwrap_or(&0); if first_char_in_name != 0 { // This is a named prepared statement // Server connection state will need to be cleared at checkin @@ -904,10 +906,10 @@ where } } - self.send_and_receive_loop(code, &message_buffer, server, &address, &pool) + self.send_and_receive_loop(code, &client_message_buffer, server, &address, &pool, &mut server_message_buffer) .await?; - message_buffer.clear(); + client_message_buffer.clear(); if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -934,9 +936,9 @@ where self.send_server_message(server, &message, &address, &pool) .await?; - let response = self.receive_server_message(server, &address, &pool).await?; + self.receive_server_message(server, &address, &pool, &mut server_message_buffer).await?; - match write_all_half(&mut self.write, &response).await { + match write_all_half(&mut self.write, &server_message_buffer).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -944,6 +946,8 @@ where } }; + server_message_buffer.clear(); + if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -987,6 +991,7 @@ where server: &mut Server, address: &Address, pool: &ConnectionPool, + server_message_buffer: &mut BytesMut, ) -> Result<(), Error> { debug!("Sending {} to server", code); @@ -997,9 +1002,9 @@ where // Read all data the server has to offer, which can be multiple messages // buffered in 8196 bytes chunks. loop { - let response = self.receive_server_message(server, &address, &pool).await?; + self.receive_server_message(server, &address, &pool, server_message_buffer).await?; - match write_all_half(&mut self.write, &response).await { + match write_all_half(&mut self.write, &server_message_buffer).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1007,6 +1012,8 @@ where } }; + server_message_buffer.clear(); + if !server.is_data_available() { break; } @@ -1043,16 +1050,17 @@ where server: &mut Server, address: &Address, pool: &ConnectionPool, - ) -> Result { + server_message_buffer: &mut BytesMut, + ) -> Result<(), Error> { if pool.settings.user.statement_timeout > 0 { match tokio::time::timeout( tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), - server.recv(), + server.recv(server_message_buffer), ) .await { Ok(result) => match result { - Ok(message) => Ok(message), + Ok(_) => Ok(()), Err(err) => { pool.ban(address, self.process_id); error_response_terminal( @@ -1075,8 +1083,8 @@ where } } } else { - match server.recv().await { - Ok(message) => Ok(message), + match server.recv(server_message_buffer).await { + Ok(_) => Ok(()), Err(err) => { pool.ban(address, self.process_id); error_response_terminal( diff --git a/src/server.rs b/src/server.rs index 899c5004..5d0b7aaf 100644 --- a/src/server.rs +++ b/src/server.rs @@ -391,9 +391,7 @@ impl Server { /// Receive data from the server in response to a client request. /// This method must be called multiple times while `self.is_data_available()` is true /// in order to receive all data the server has to offer. - pub async fn recv(&mut self) -> Result { - // Our server response buffer. We buffer data before we give it to the client. - let mut message_buffer = BytesMut::with_capacity(8196); + pub async fn recv(&mut self, server_message_buffer: &mut BytesMut) -> Result<(), Error> { loop { let mut message = match read_message(&mut self.read).await { @@ -406,7 +404,7 @@ impl Server { }; // Buffer the message we'll forward to the client later. - message_buffer.put(&message[..]); + server_message_buffer.put(&message[..]); let code = message.get_u8() as char; let _len = message.get_i32(); @@ -486,7 +484,7 @@ impl Server { self.data_available = true; // Don't flush yet, the more we buffer, the faster this goes...up to a limit. - if message_buffer.len() >= 8196 { + if server_message_buffer.len() >= 8196 { break; } } @@ -516,13 +514,12 @@ impl Server { // Keep track of how much data we got from the server for stats. self.stats - .data_received(message_buffer.len(), self.server_id); + .data_received(server_message_buffer.len(), self.server_id); // Successfully received data from server self.last_activity = SystemTime::now(); - // Pass the data back to the client. - Ok(message_buffer) + Ok(()) } /// If the server is still inside a transaction. @@ -577,8 +574,10 @@ impl Server { self.send(&query).await?; + let mut server_message_buffer = BytesMut::with_capacity(8196); + loop { - let _ = self.recv().await?; + let _ = self.recv(&mut server_message_buffer).await?; if !self.data_available { break; From 09db31db4a0a7168fac68d18af820edc53e8b0ba Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 17 Oct 2022 10:54:26 -0700 Subject: [PATCH 06/35] fmt --- src/client.rs | 36 +++++++++++++++++++++++++++++------- src/server.rs | 1 - 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/client.rs b/src/client.rs index 7eb7de22..efb49b89 100644 --- a/src/client.rs +++ b/src/client.rs @@ -836,8 +836,15 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop(code, &message, server, &address, &pool, &mut server_message_buffer) - .await?; + self.send_and_receive_loop( + code, + &message, + server, + &address, + &pool, + &mut server_message_buffer, + ) + .await?; if !server.in_transaction() { // Report transaction executed statistics. @@ -890,7 +897,8 @@ where client_message_buffer.put(&message[..]); - let first_message_code = (*client_message_buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = + (*client_message_buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' { @@ -906,8 +914,15 @@ where } } - self.send_and_receive_loop(code, &client_message_buffer, server, &address, &pool, &mut server_message_buffer) - .await?; + self.send_and_receive_loop( + code, + &client_message_buffer, + server, + &address, + &pool, + &mut server_message_buffer, + ) + .await?; client_message_buffer.clear(); @@ -936,7 +951,13 @@ where self.send_server_message(server, &message, &address, &pool) .await?; - self.receive_server_message(server, &address, &pool, &mut server_message_buffer).await?; + self.receive_server_message( + server, + &address, + &pool, + &mut server_message_buffer, + ) + .await?; match write_all_half(&mut self.write, &server_message_buffer).await { Ok(_) => (), @@ -1002,7 +1023,8 @@ where // Read all data the server has to offer, which can be multiple messages // buffered in 8196 bytes chunks. loop { - self.receive_server_message(server, &address, &pool, server_message_buffer).await?; + self.receive_server_message(server, &address, &pool, server_message_buffer) + .await?; match write_all_half(&mut self.write, &server_message_buffer).await { Ok(_) => (), diff --git a/src/server.rs b/src/server.rs index 5d0b7aaf..964f8d14 100644 --- a/src/server.rs +++ b/src/server.rs @@ -392,7 +392,6 @@ impl Server { /// This method must be called multiple times while `self.is_data_available()` is true /// in order to receive all data the server has to offer. pub async fn recv(&mut self, server_message_buffer: &mut BytesMut) -> Result<(), Error> { - loop { let mut message = match read_message(&mut self.read).await { Ok(message) => message, From 8b47f32ae44a9013c46fa7e08eddf1505128b8ef Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 14 Nov 2022 20:51:19 -0500 Subject: [PATCH 07/35] Read message directly onto buffer instead of new bytesmut --- src/client.rs | 102 ++++++++++++++++++++++++++++++------------------ src/messages.rs | 15 ++++--- src/server.rs | 16 ++++---- 3 files changed, 81 insertions(+), 52 deletions(-) diff --git a/src/client.rs b/src/client.rs index efb49b89..3e264465 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,6 +2,7 @@ use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; use std::collections::HashMap; +use std::io::Cursor; use std::time::Instant; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; @@ -586,6 +587,8 @@ where let mut server_message_buffer = BytesMut::with_capacity(8196); + let mut clear_client_message_buffer = false; + // Our custom protocol loop. // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocol. @@ -595,13 +598,19 @@ where self.transaction_mode ); + if clear_client_message_buffer { + client_message_buffer.clear(); + } + + clear_client_message_buffer = true; + // Read a complete message from the client, which normally would be // either a `Q` (query) or `P` (prepare, extended protocol). // We can parse it here before grabbing a server from the pool, // in case the client is sending some custom protocol messages, e.g. // SET SHARDING KEY TO 'bigint'; - let message = tokio::select! { + let message_start = tokio::select! { _ = self.shutdown.recv() => { if !self.admin { error_response_terminal( @@ -613,13 +622,18 @@ where // Admin clients ignore shutdown. else { - read_message(&mut self.read).await? + read_message(&mut self.read, &mut client_message_buffer).await? } }, - message_result = read_message(&mut self.read) => message_result? + message_result = read_message(&mut self.read, &mut client_message_buffer) => message_result? }; - match message[0] as char { + let mut message_cursor = Cursor::new(&client_message_buffer); + message_cursor.advance(message_start); + + let message_code = message_cursor.get_u8(); + + match message_code as char { // Buffer extended protocol messages even if we do not have // a server connection yet. Hopefully, when we get the S message // we'll be able to allocate a connection. Also, clients do not expect @@ -628,7 +642,8 @@ where // to the client so we buffer them and defer the decision to error out or not // to when we get the S message 'P' | 'B' | 'D' | 'E' => { - client_message_buffer.put(&message[..]); + // client_message_buffer.put(&message[..]); + clear_client_message_buffer = false; continue; } 'X' => { @@ -641,7 +656,12 @@ where // Handle admin database queries. if self.admin { debug!("Handling admin command"); - handle_admin(&mut self.write, message, self.client_server_map.clone()).await?; + handle_admin( + &mut self.write, + client_message_buffer.clone(), + self.client_server_map.clone(), + ) + .await?; continue; } @@ -668,11 +688,11 @@ where let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. - match query_router.try_execute_command(message.clone()) { + match query_router.try_execute_command(client_message_buffer.clone()) { // Normal query, not a custom command. None => { if query_router.query_parser_enabled() { - query_router.infer_role(message.clone()); + query_router.infer_role(client_message_buffer.clone()); } } @@ -752,7 +772,7 @@ where // but we were unable to grab a connection from the pool // We'll send back an error message and clean the extended // protocol buffer - if message[0] as char == 'S' { + if message_code as char == 'S' { error!("Got Sync message but failed to get a connection from the pool"); client_message_buffer.clear(); } @@ -792,7 +812,7 @@ where // Set application_name. server.set_name(&self.application_name).await?; - let mut initial_message = Some(message); + let mut initial_message_start = Some(message_start); // Transaction loop. Multiple queries can be issued by the client here. // The connection belongs to the client until the transaction is over, @@ -801,11 +821,11 @@ where // If the client is in session mode, no more custom protocol // commands will be accepted. loop { - let message = match initial_message { + let message_start = match initial_message_start { None => { trace!("Waiting for message inside transaction or in session mode"); - match read_message(&mut self.read).await { + match read_message(&mut self.read, &mut client_message_buffer).await { Ok(message) => message, Err(err) => { // Client disconnected inside a transaction. @@ -816,18 +836,18 @@ where } } } - Some(message) => { - initial_message = None; - message + Some(message_start) => { + initial_message_start = None; + message_start } }; + let mut message_cursor = Cursor::new(&client_message_buffer); + message_cursor.advance(message_start); + // The message will be forwarded to the server intact. We still would like to // parse it below to figure out what to do with it. - - // Safe to unwrap because we know this message has a certain length and has the code - // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.get(0).unwrap() as char; + let code = message_cursor.get_u8() as char; trace!("Message: {}", code); @@ -838,7 +858,7 @@ where self.send_and_receive_loop( code, - &message, + &mut client_message_buffer, server, &address, &pool, @@ -869,25 +889,21 @@ where // Parse // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. 'P' => { - client_message_buffer.put(&message[..]); } // Bind // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' 'B' => { - client_message_buffer.put(&message[..]); } // Describe // Command a client can issue to describe a previously prepared named statement. 'D' => { - client_message_buffer.put(&message[..]); } // Execute // Execute a prepared statement prepared in `P` and bound in `B`. 'E' => { - client_message_buffer.put(&message[..]); } // Sync @@ -895,8 +911,6 @@ where 'S' => { debug!("Sending query to server"); - client_message_buffer.put(&message[..]); - let first_message_code = (*client_message_buffer.get(0).unwrap_or(&0)) as char; @@ -916,7 +930,7 @@ where self.send_and_receive_loop( code, - &client_message_buffer, + &mut client_message_buffer, server, &address, &pool, @@ -924,8 +938,6 @@ where ) .await?; - client_message_buffer.clear(); - if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -941,15 +953,25 @@ where 'd' => { // Forward the data to the server, // don't buffer it since it can be rather large. - self.send_server_message(server, &message, &address, &pool) - .await?; + self.send_client_message_to_server( + server, + &mut client_message_buffer, + &address, + &pool, + ) + .await?; } // CopyDone or CopyFail // Copy is done, successfully or not. 'c' | 'f' => { - self.send_server_message(server, &message, &address, &pool) - .await?; + self.send_client_message_to_server( + server, + &mut client_message_buffer, + &address, + &pool, + ) + .await?; self.receive_server_message( server, @@ -1008,7 +1030,7 @@ where async fn send_and_receive_loop( &mut self, code: char, - message: &BytesMut, + message: &mut BytesMut, server: &mut Server, address: &Address, pool: &ConnectionPool, @@ -1016,7 +1038,7 @@ where ) -> Result<(), Error> { debug!("Sending {} to server", code); - self.send_server_message(server, message, &address, &pool) + self.send_client_message_to_server(server, message, &address, &pool) .await?; let query_start = Instant::now(); @@ -1051,17 +1073,21 @@ where Ok(()) } - async fn send_server_message( + async fn send_client_message_to_server( &self, server: &mut Server, - message: &BytesMut, + message: &mut BytesMut, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { match server.send(message).await { - Ok(_) => Ok(()), + Ok(_) => { + message.clear(); + Ok(()) + } Err(err) => { pool.ban(address, self.process_id); + message.clear(); Err(err) } } diff --git a/src/messages.rs b/src/messages.rs index 8fffcfb5..1c0dfb4b 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -466,10 +466,12 @@ where } /// Read a complete message from the socket. -pub async fn read_message(stream: &mut S) -> Result +pub async fn read_message(stream: &mut S, buffer: &mut BytesMut) -> Result where S: tokio::io::AsyncRead + std::marker::Unpin, { + let starting_point = buffer.len(); + let code = match stream.read_u8().await { Ok(code) => code, Err(_) => return Err(Error::SocketError), @@ -487,13 +489,14 @@ where Err(_) => return Err(Error::SocketError), }; - let mut bytes = BytesMut::with_capacity(len as usize + 1); + buffer.put_u8(code); + buffer.put_i32(len); + buffer.put_slice(&buf); - bytes.put_u8(code); - bytes.put_i32(len); - bytes.put_slice(&buf); + // let mut cursor = Cursor::new(buffer); + // cursor.advance(starting_point); - Ok(bytes) + Ok(starting_point) } pub fn server_parameter_message(key: &str, value: &str) -> BytesMut { diff --git a/src/server.rs b/src/server.rs index 964f8d14..78182758 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,7 +2,7 @@ /// Here we are pretending to the a Postgres client. use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; -use std::io::Read; +use std::io::{Cursor, Read}; use std::time::SystemTime; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ @@ -393,7 +393,7 @@ impl Server { /// in order to receive all data the server has to offer. pub async fn recv(&mut self, server_message_buffer: &mut BytesMut) -> Result<(), Error> { loop { - let mut message = match read_message(&mut self.read).await { + let message_start = match read_message(&mut self.read, server_message_buffer).await { Ok(message) => message, Err(err) => { error!("Terminating server because of: {:?}", err); @@ -402,18 +402,18 @@ impl Server { } }; - // Buffer the message we'll forward to the client later. - server_message_buffer.put(&message[..]); + let mut message_cursor = Cursor::new(&server_message_buffer); + message_cursor.advance(message_start); - let code = message.get_u8() as char; - let _len = message.get_i32(); + let code = message_cursor.get_u8() as char; + let _len = message_cursor.get_i32(); trace!("Message: {}", code); match code { // ReadyForQuery 'Z' => { - let transaction_state = message.get_u8() as char; + let transaction_state = message_cursor.get_u8() as char; match transaction_state { // In transaction. @@ -447,7 +447,7 @@ impl Server { // CommandComplete 'C' => { let mut command_tag = String::new(); - match message.reader().read_to_string(&mut command_tag) { + match message_cursor.reader().read_to_string(&mut command_tag) { Ok(_) => { // Non-exhaustive list of commands that are likely to change session variables/resources // which can leak between clients. This is a best effort to block bad clients From 1589548b271362e559a12518cfe38d86cc07bca1 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 14 Nov 2022 21:22:39 -0500 Subject: [PATCH 08/35] remove commented code --- src/messages.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/messages.rs b/src/messages.rs index ffc147d2..18cc65d9 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -493,9 +493,6 @@ where buffer.put_i32(len); buffer.put_slice(&buf); - // let mut cursor = Cursor::new(buffer); - // cursor.advance(starting_point); - Ok(starting_point) } From e2b2cb0bf05a8c9e57f193e37456de20661a121b Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 14:36:46 -0500 Subject: [PATCH 09/35] Move server buffer to server object --- src/client.rs | 21 +++++++-------------- src/server.rs | 23 +++++++++++++++-------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/client.rs b/src/client.rs index ec5fdd13..b3d28e20 100644 --- a/src/client.rs +++ b/src/client.rs @@ -582,8 +582,6 @@ where // them to the backend. let mut client_message_buffer = BytesMut::with_capacity(8196); - let mut server_message_buffer = BytesMut::with_capacity(8196); - let mut clear_client_message_buffer = false; // Our custom protocol loop. @@ -859,7 +857,6 @@ where server, &address, &pool, - &mut server_message_buffer, ) .await?; @@ -927,7 +924,6 @@ where server, &address, &pool, - &mut server_message_buffer, ) .await?; @@ -970,11 +966,10 @@ where server, &address, &pool, - &mut server_message_buffer, ) .await?; - match write_all_half(&mut self.write, &server_message_buffer).await { + match write_all_half(&mut self.write, &server.server_message_buffer).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -982,7 +977,7 @@ where } }; - server_message_buffer.clear(); + server.clear_server_message_buffer(); if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -1027,7 +1022,6 @@ where server: &mut Server, address: &Address, pool: &ConnectionPool, - server_message_buffer: &mut BytesMut, ) -> Result<(), Error> { debug!("Sending {} to server", code); @@ -1038,10 +1032,10 @@ where // Read all data the server has to offer, which can be multiple messages // buffered in 8196 bytes chunks. loop { - self.receive_server_message(server, &address, &pool, server_message_buffer) + self.receive_server_message(server, &address, &pool) .await?; - match write_all_half(&mut self.write, &server_message_buffer).await { + match write_all_half(&mut self.write, &server.server_message_buffer).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1049,7 +1043,7 @@ where } }; - server_message_buffer.clear(); + server.clear_server_message_buffer(); if !server.is_data_available() { break; @@ -1091,12 +1085,11 @@ where server: &mut Server, address: &Address, pool: &ConnectionPool, - server_message_buffer: &mut BytesMut, ) -> Result<(), Error> { if pool.settings.user.statement_timeout > 0 { match tokio::time::timeout( tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), - server.recv(server_message_buffer), + server.recv(), ) .await { @@ -1124,7 +1117,7 @@ where } } } else { - match server.recv(server_message_buffer).await { + match server.recv().await { Ok(_) => Ok(()), Err(err) => { pool.ban(address, self.process_id); diff --git a/src/server.rs b/src/server.rs index 4f220213..fb8f65c6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -65,6 +65,8 @@ pub struct Server { // Last time that a successful server send or response happened last_activity: SystemTime, + + pub server_message_buffer: BytesMut } impl Server { @@ -313,6 +315,7 @@ impl Server { address: address.clone(), read: BufReader::new(read), write: write, + server_message_buffer: BytesMut::with_capacity(8196), server_info: server_info, server_id: server_id, process_id: process_id, @@ -391,9 +394,9 @@ impl Server { /// Receive data from the server in response to a client request. /// This method must be called multiple times while `self.is_data_available()` is true /// in order to receive all data the server has to offer. - pub async fn recv(&mut self, server_message_buffer: &mut BytesMut) -> Result<(), Error> { + pub async fn recv(&mut self) -> Result<(), Error> { loop { - let message_start = match read_message(&mut self.read, server_message_buffer).await { + let message_start = match read_message(&mut self.read, &mut self.server_message_buffer).await { Ok(message) => message, Err(err) => { error!("Terminating server because of: {:?}", err); @@ -402,7 +405,7 @@ impl Server { } }; - let mut message_cursor = Cursor::new(&server_message_buffer); + let mut message_cursor = Cursor::new(&self.server_message_buffer); message_cursor.advance(message_start); let code = message_cursor.get_u8() as char; @@ -483,7 +486,7 @@ impl Server { self.data_available = true; // Don't flush yet, the more we buffer, the faster this goes...up to a limit. - if server_message_buffer.len() >= 8196 { + if self.server_message_buffer.len() >= 8196 { break; } } @@ -513,7 +516,7 @@ impl Server { // Keep track of how much data we got from the server for stats. self.stats - .data_received(server_message_buffer.len(), self.server_id); + .data_received(self.server_message_buffer.len(), self.server_id); // Successfully received data from server self.last_activity = SystemTime::now(); @@ -573,16 +576,16 @@ impl Server { self.send(&query).await?; - let mut server_message_buffer = BytesMut::with_capacity(8196); - loop { - let _ = self.recv(&mut server_message_buffer).await?; + let _ = self.recv().await?; if !self.data_available { break; } } + self.clear_server_message_buffer(); + Ok(()) } @@ -651,6 +654,10 @@ impl Server { pub fn mark_dirty(&mut self) { self.needs_cleanup = true; } + + pub fn clear_server_message_buffer(&mut self) { + self.server_message_buffer.clear(); + } } impl Drop for Server { From 8d68d22ad36adb7af66d7848a0e1a6e245f6da71 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 14:54:06 -0500 Subject: [PATCH 10/35] Move client and server message buffers to attribute of object --- src/client.rs | 95 ++++++++++++++++++--------------------------------- src/server.rs | 22 ++++++------ 2 files changed, 46 insertions(+), 71 deletions(-) diff --git a/src/client.rs b/src/client.rs index b3d28e20..c9fef4aa 100644 --- a/src/client.rs +++ b/src/client.rs @@ -38,6 +38,10 @@ pub struct Client { /// better than a stock buffer. write: T, + /// Internal buffer, where we place messages until we have to flush + /// them to the backend. + client_message_buffer: BytesMut, + /// Address addr: std::net::SocketAddr, @@ -487,6 +491,7 @@ where Ok(Client { read: BufReader::new(read), write, + client_message_buffer: BytesMut::with_capacity(8196), addr, cancel_mode: false, transaction_mode, @@ -520,6 +525,7 @@ where Ok(Client { read: BufReader::new(read), write, + client_message_buffer: BytesMut::with_capacity(8196), addr, cancel_mode: true, transaction_mode: false, @@ -578,10 +584,6 @@ where self.application_name.clone(), ); - // Internal buffer, where we place messages until we have to flush - // them to the backend. - let mut client_message_buffer = BytesMut::with_capacity(8196); - let mut clear_client_message_buffer = false; // Our custom protocol loop. @@ -594,7 +596,7 @@ where ); if clear_client_message_buffer { - client_message_buffer.clear(); + self.client_message_buffer.clear(); } clear_client_message_buffer = true; @@ -617,13 +619,13 @@ where // Admin clients ignore shutdown. else { - read_message(&mut self.read, &mut client_message_buffer).await? + read_message(&mut self.read, &mut self.client_message_buffer).await? } }, - message_result = read_message(&mut self.read, &mut client_message_buffer) => message_result? + message_result = read_message(&mut self.read, &mut self.client_message_buffer) => message_result? }; - let mut message_cursor = Cursor::new(&client_message_buffer); + let mut message_cursor = Cursor::new(&self.client_message_buffer); message_cursor.advance(message_start); let message_code = message_cursor.get_u8(); @@ -653,7 +655,7 @@ where debug!("Handling admin command"); handle_admin( &mut self.write, - client_message_buffer.clone(), + self.client_message_buffer.clone(), self.client_server_map.clone(), ) .await?; @@ -683,11 +685,11 @@ where let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. - match query_router.try_execute_command(client_message_buffer.clone()) { + match query_router.try_execute_command(self.client_message_buffer.clone()) { // Normal query, not a custom command. None => { if query_router.query_parser_enabled() { - query_router.infer(client_message_buffer.clone()); + query_router.infer(self.client_message_buffer.clone()); } } @@ -769,7 +771,7 @@ where // protocol buffer if message_code as char == 'S' { error!("Got Sync message but failed to get a connection from the pool"); - client_message_buffer.clear(); + self.client_message_buffer.clear(); } error_response(&mut self.write, "could not get connection from the pool") .await?; @@ -820,7 +822,7 @@ where None => { trace!("Waiting for message inside transaction or in session mode"); - match read_message(&mut self.read, &mut client_message_buffer).await { + match read_message(&mut self.read, &mut self.client_message_buffer).await { Ok(message) => message, Err(err) => { // Client disconnected inside a transaction. @@ -837,7 +839,7 @@ where } }; - let mut message_cursor = Cursor::new(&client_message_buffer); + let mut message_cursor = Cursor::new(&self.client_message_buffer); message_cursor.advance(message_start); // The message will be forwarded to the server intact. We still would like to @@ -851,14 +853,8 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop( - code, - &mut client_message_buffer, - server, - &address, - &pool, - ) - .await?; + self.send_and_receive_loop(code, server, &address, &pool) + .await?; if !server.in_transaction() { // Report transaction executed statistics. @@ -902,7 +898,7 @@ where debug!("Sending query to server"); let first_message_code = - (*client_message_buffer.get(0).unwrap_or(&0)) as char; + (*self.client_message_buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' { @@ -910,7 +906,8 @@ where // P followed by 32 int followed by null-terminated statement name // So message code should be in offset 0 of the buffer, first character // in prepared statement name would be index 5 - let first_char_in_name = *client_message_buffer.get(5).unwrap_or(&0); + let first_char_in_name = + *self.client_message_buffer.get(5).unwrap_or(&0); if first_char_in_name != 0 { // This is a named prepared statement // Server connection state will need to be cleared at checkin @@ -918,14 +915,8 @@ where } } - self.send_and_receive_loop( - code, - &mut client_message_buffer, - server, - &address, - &pool, - ) - .await?; + self.send_and_receive_loop(code, server, &address, &pool) + .await?; if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -942,32 +933,17 @@ where 'd' => { // Forward the data to the server, // don't buffer it since it can be rather large. - self.send_client_message_to_server( - server, - &mut client_message_buffer, - &address, - &pool, - ) - .await?; + self.send_client_message_to_server(server, &address, &pool) + .await?; } // CopyDone or CopyFail // Copy is done, successfully or not. 'c' | 'f' => { - self.send_client_message_to_server( - server, - &mut client_message_buffer, - &address, - &pool, - ) - .await?; + self.send_client_message_to_server(server, &address, &pool) + .await?; - self.receive_server_message( - server, - &address, - &pool, - ) - .await?; + self.receive_server_message(server, &address, &pool).await?; match write_all_half(&mut self.write, &server.server_message_buffer).await { Ok(_) => (), @@ -1018,22 +994,20 @@ where async fn send_and_receive_loop( &mut self, code: char, - message: &mut BytesMut, server: &mut Server, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { debug!("Sending {} to server", code); - self.send_client_message_to_server(server, message, &address, &pool) + self.send_client_message_to_server(server, &address, &pool) .await?; let query_start = Instant::now(); // Read all data the server has to offer, which can be multiple messages // buffered in 8196 bytes chunks. loop { - self.receive_server_message(server, &address, &pool) - .await?; + self.receive_server_message(server, &address, &pool).await?; match write_all_half(&mut self.write, &server.server_message_buffer).await { Ok(_) => (), @@ -1061,20 +1035,19 @@ where } async fn send_client_message_to_server( - &self, + &mut self, server: &mut Server, - message: &mut BytesMut, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { - match server.send(message).await { + match server.send(&self.client_message_buffer).await { Ok(_) => { - message.clear(); + self.client_message_buffer.clear(); Ok(()) } Err(err) => { pool.ban(address, self.process_id); - message.clear(); + self.client_message_buffer.clear(); Err(err) } } diff --git a/src/server.rs b/src/server.rs index fb8f65c6..a85ddfda 100644 --- a/src/server.rs +++ b/src/server.rs @@ -32,6 +32,9 @@ pub struct Server { /// Unbuffered write socket (our client code buffers). write: OwnedWriteHalf, + /// Our server response buffer. We buffer data before we give it to the client. + pub server_message_buffer: BytesMut, + /// Server information the server sent us over on startup. server_info: BytesMut, @@ -65,8 +68,6 @@ pub struct Server { // Last time that a successful server send or response happened last_activity: SystemTime, - - pub server_message_buffer: BytesMut } impl Server { @@ -396,14 +397,15 @@ impl Server { /// in order to receive all data the server has to offer. pub async fn recv(&mut self) -> Result<(), Error> { loop { - let message_start = match read_message(&mut self.read, &mut self.server_message_buffer).await { - Ok(message) => message, - Err(err) => { - error!("Terminating server because of: {:?}", err); - self.bad = true; - return Err(err); - } - }; + let message_start = + match read_message(&mut self.read, &mut self.server_message_buffer).await { + Ok(message) => message, + Err(err) => { + error!("Terminating server because of: {:?}", err); + self.bad = true; + return Err(err); + } + }; let mut message_cursor = Cursor::new(&self.server_message_buffer); message_cursor.advance(message_start); From f019e28afac05cbdb4d790490aefdefd8630cf70 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 15:00:15 -0500 Subject: [PATCH 11/35] Remove clear buffer function since public already --- src/client.rs | 6 +++--- src/server.rs | 6 +----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/client.rs b/src/client.rs index c9fef4aa..76c720fe 100644 --- a/src/client.rs +++ b/src/client.rs @@ -491,8 +491,8 @@ where Ok(Client { read: BufReader::new(read), write, - client_message_buffer: BytesMut::with_capacity(8196), addr, + client_message_buffer: BytesMut::with_capacity(8196), cancel_mode: false, transaction_mode, process_id, @@ -953,7 +953,7 @@ where } }; - server.clear_server_message_buffer(); + server.server_message_buffer.clear(); if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -1017,7 +1017,7 @@ where } }; - server.clear_server_message_buffer(); + server.server_message_buffer.clear(); if !server.is_data_available() { break; diff --git a/src/server.rs b/src/server.rs index a85ddfda..fc07ea09 100644 --- a/src/server.rs +++ b/src/server.rs @@ -586,7 +586,7 @@ impl Server { } } - self.clear_server_message_buffer(); + self.server_message_buffer.clear(); Ok(()) } @@ -656,10 +656,6 @@ impl Server { pub fn mark_dirty(&mut self) { self.needs_cleanup = true; } - - pub fn clear_server_message_buffer(&mut self) { - self.server_message_buffer.clear(); - } } impl Drop for Server { From 4cb38b83c2986be991ea2f71c0686512c58a766a Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 15:09:47 -0500 Subject: [PATCH 12/35] Remove commented code --- src/client.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/client.rs b/src/client.rs index 76c720fe..789bf4a6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -639,7 +639,6 @@ where // to the client so we buffer them and defer the decision to error out or not // to when we get the S message 'P' | 'B' | 'D' | 'E' => { - // client_message_buffer.put(&message[..]); clear_client_message_buffer = false; continue; } From 8f3fc253d17a071f8524749dd1aca0b1dfc8f36d Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 16:08:26 -0500 Subject: [PATCH 13/35] Remove cloning operation for query router --- src/client.rs | 4 +-- src/errors.rs | 1 + src/lib.rs | 21 +++++++++++++ src/query_router.rs | 77 ++++++++++++++++++++++++--------------------- 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/src/client.rs b/src/client.rs index 789bf4a6..8b5d4b9f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -684,11 +684,11 @@ where let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. - match query_router.try_execute_command(self.client_message_buffer.clone()) { + match query_router.try_execute_command(&self.client_message_buffer) { // Normal query, not a custom command. None => { if query_router.query_parser_enabled() { - query_router.infer(self.client_message_buffer.clone()); + query_router.infer(&self.client_message_buffer); } } diff --git a/src/errors.rs b/src/errors.rs index 50301f36..91331b94 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -13,4 +13,5 @@ pub enum Error { TlsError, StatementTimeout, ShuttingDown, + ParseBytesError(String), } diff --git a/src/lib.rs b/src/lib.rs index e9a683f3..32a5f737 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,8 @@ +use std::io::{BufRead, Cursor}; + +use bytes::BytesMut; +use errors::Error; + pub mod config; pub mod constants; pub mod errors; @@ -30,3 +35,19 @@ pub fn format_duration(duration: &chrono::Duration) -> String { days, hours, minutes, seconds, milliseconds ) } + +pub trait BytesMutReader { + fn read_string(&mut self) -> Result; +} + +impl BytesMutReader for Cursor<&BytesMut> { + fn read_string(&mut self) -> Result { + let mut buf = vec![]; + match self.read_until(b'\0', &mut buf) { + Ok(_) => {} + Err(err) => return Err(Error::ParseBytesError(err.to_string())), + }; + + Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()) + } +} diff --git a/src/query_router.rs b/src/query_router.rs index 552c358c..ff5466c7 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -14,6 +14,9 @@ use crate::pool::PoolSettings; use crate::sharding::Sharder; use std::collections::BTreeSet; +use std::io::Cursor; + +use pgcat::BytesMutReader; /// Regexes used to parse custom commands. const CUSTOM_SQL_REGEXES: [&str; 7] = [ @@ -107,16 +110,18 @@ impl QueryRouter { } /// Try to parse a command and execute it. - pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> { - let code = buf.get_u8() as char; + pub fn try_execute_command(&mut self, buf: &BytesMut) -> Option<(Command, String)> { + let mut message_cursor = Cursor::new(buf); + + let code = message_cursor.get_u8() as char; // Only simple protocol supported for commands. if code != 'Q' { return None; } - let len = buf.get_i32() as usize; - let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); // Ignore the terminating NULL. + let _len = message_cursor.get_i32() as usize; + let query = message_cursor.read_string().unwrap(); let regex_set = match CUSTOM_SQL_REGEX_SET.get() { Some(regex_set) => regex_set, @@ -256,16 +261,18 @@ impl QueryRouter { } /// Try to infer which server to connect to based on the contents of the query. - pub fn infer(&mut self, mut buf: BytesMut) -> bool { + pub fn infer(&mut self, buf: &BytesMut) -> bool { debug!("Inferring role"); - let code = buf.get_u8() as char; - let len = buf.get_i32() as usize; + let mut message_cursor = Cursor::new(buf); + + let code = message_cursor.get_u8() as char; + let _len = message_cursor.get_i32() as usize; let query = match code { // Query 'Q' => { - let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); + let query = message_cursor.read_string().unwrap(); debug!("Query: '{}'", query); query } @@ -519,10 +526,10 @@ mod test { fn test_infer_replica() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); + assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr.query_parser_enabled()); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); let queries = vec![ simple_query("SELECT * FROM items WHERE id = 5"), @@ -534,7 +541,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Replica)); } } @@ -553,7 +560,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Primary)); } } @@ -563,9 +570,9 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), None); } @@ -573,8 +580,8 @@ mod test { fn test_infer_parse_prepared() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); let prepared_stmt = BytesMut::from( &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], @@ -585,7 +592,7 @@ mod test { res.put(prepared_stmt); res.put_i16(0); - assert!(qr.infer(res)); + assert!(qr.infer(&res)); assert_eq!(qr.role(), Some(Role::Replica)); } @@ -668,7 +675,7 @@ mod test { // SetShardingKey let query = simple_query("SET SHARDING KEY TO 13"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetShardingKey, String::from("0"))) ); assert_eq!(qr.shard(), 0); @@ -676,7 +683,7 @@ mod test { // SetShard let query = simple_query("SET SHARD TO '1'"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetShard, String::from("1"))) ); assert_eq!(qr.shard(), 1); @@ -684,7 +691,7 @@ mod test { // ShowShard let query = simple_query("SHOW SHARD"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::ShowShard, String::from("1"))) ); @@ -702,7 +709,7 @@ mod test { for (idx, role) in roles.iter().enumerate() { let query = simple_query(&format!("SET SERVER ROLE TO '{}'", role)); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetServerRole, String::from(*role))) ); assert_eq!(qr.role(), verify_roles[idx],); @@ -711,7 +718,7 @@ mod test { // ShowServerRole let query = simple_query("SHOW SERVER ROLE"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::ShowServerRole, String::from(*role))) ); } @@ -721,14 +728,14 @@ mod test { for (idx, primary_reads) in primary_reads.iter().enumerate() { assert_eq!( - qr.try_execute_command(simple_query(&format!( + qr.try_execute_command(&simple_query(&format!( "SET PRIMARY READS TO {}", primary_reads ))), Some((Command::SetPrimaryReads, String::from(*primary_reads))) ); assert_eq!( - qr.try_execute_command(simple_query("SHOW PRIMARY READS")), + qr.try_execute_command(&simple_query("SHOW PRIMARY READS")), Some(( Command::ShowPrimaryReads, String::from(primary_reads_enabled[idx]) @@ -742,23 +749,23 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SET SERVER ROLE TO 'auto'"); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); - assert!(qr.try_execute_command(query) != None); + assert!(qr.try_execute_command(&query) != None); assert!(qr.query_parser_enabled()); assert_eq!(qr.role(), None); let query = simple_query("INSERT INTO test_table VALUES (1)"); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Primary)); let query = simple_query("SELECT * FROM test_table"); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Replica)); assert!(qr.query_parser_enabled()); let query = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(query) != None); + assert!(qr.try_execute_command(&query) != None); assert!(!qr.query_parser_enabled()); } @@ -791,16 +798,16 @@ mod test { assert!(!qr.primary_reads_enabled()); let q1 = simple_query("SET SERVER ROLE TO 'primary'"); - assert!(qr.try_execute_command(q1) != None); + assert!(qr.try_execute_command(&q1) != None); assert_eq!(qr.active_role.unwrap(), Role::Primary); let q2 = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(q2) != None); + assert!(qr.try_execute_command(&q2) != None); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); // Here we go :) let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)"); - assert!(qr.infer(q3)); + assert!(qr.infer(&q3)); assert_eq!(qr.shard(), 1); } @@ -809,13 +816,13 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.infer(simple_query("BEGIN; SELECT 1; COMMIT;"))); + assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;"))); assert_eq!(qr.role(), Role::Primary); - assert!(qr.infer(simple_query("SELECT 1; SELECT 2;"))); + assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;"))); assert_eq!(qr.role(), Role::Replica); - assert!(qr.infer(simple_query( + assert!(qr.infer(&simple_query( "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" ))); assert_eq!(qr.role(), Role::Primary); From 728f6cffac1a363d7ea48940e2206ff09ce4a3f2 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 16:26:11 -0500 Subject: [PATCH 14/35] Rename server and client buffers to buffer --- src/client.rs | 52 +++++++++++++++++++++++++-------------------------- src/server.rs | 14 +++++++------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/client.rs b/src/client.rs index 8b5d4b9f..26a182a8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -40,7 +40,7 @@ pub struct Client { /// Internal buffer, where we place messages until we have to flush /// them to the backend. - client_message_buffer: BytesMut, + buffer: BytesMut, /// Address addr: std::net::SocketAddr, @@ -492,7 +492,7 @@ where read: BufReader::new(read), write, addr, - client_message_buffer: BytesMut::with_capacity(8196), + buffer: BytesMut::with_capacity(8196), cancel_mode: false, transaction_mode, process_id, @@ -525,7 +525,7 @@ where Ok(Client { read: BufReader::new(read), write, - client_message_buffer: BytesMut::with_capacity(8196), + buffer: BytesMut::with_capacity(8196), addr, cancel_mode: true, transaction_mode: false, @@ -584,7 +584,7 @@ where self.application_name.clone(), ); - let mut clear_client_message_buffer = false; + let mut clear_buffer = false; // Our custom protocol loop. // We expect the client to either start a transaction with regular queries @@ -595,11 +595,11 @@ where self.transaction_mode ); - if clear_client_message_buffer { - self.client_message_buffer.clear(); + if clear_buffer { + self.buffer.clear(); } - clear_client_message_buffer = true; + clear_buffer = true; // Read a complete message from the client, which normally would be // either a `Q` (query) or `P` (prepare, extended protocol). @@ -619,13 +619,13 @@ where // Admin clients ignore shutdown. else { - read_message(&mut self.read, &mut self.client_message_buffer).await? + read_message(&mut self.read, &mut self.buffer).await? } }, - message_result = read_message(&mut self.read, &mut self.client_message_buffer) => message_result? + message_result = read_message(&mut self.read, &mut self.buffer) => message_result? }; - let mut message_cursor = Cursor::new(&self.client_message_buffer); + let mut message_cursor = Cursor::new(&self.buffer); message_cursor.advance(message_start); let message_code = message_cursor.get_u8(); @@ -639,7 +639,7 @@ where // to the client so we buffer them and defer the decision to error out or not // to when we get the S message 'P' | 'B' | 'D' | 'E' => { - clear_client_message_buffer = false; + clear_buffer = false; continue; } 'X' => { @@ -654,7 +654,7 @@ where debug!("Handling admin command"); handle_admin( &mut self.write, - self.client_message_buffer.clone(), + self.buffer.clone(), self.client_server_map.clone(), ) .await?; @@ -684,11 +684,11 @@ where let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. - match query_router.try_execute_command(&self.client_message_buffer) { + match query_router.try_execute_command(&self.buffer) { // Normal query, not a custom command. None => { if query_router.query_parser_enabled() { - query_router.infer(&self.client_message_buffer); + query_router.infer(&self.buffer); } } @@ -770,7 +770,7 @@ where // protocol buffer if message_code as char == 'S' { error!("Got Sync message but failed to get a connection from the pool"); - self.client_message_buffer.clear(); + self.buffer.clear(); } error_response(&mut self.write, "could not get connection from the pool") .await?; @@ -821,7 +821,7 @@ where None => { trace!("Waiting for message inside transaction or in session mode"); - match read_message(&mut self.read, &mut self.client_message_buffer).await { + match read_message(&mut self.read, &mut self.buffer).await { Ok(message) => message, Err(err) => { // Client disconnected inside a transaction. @@ -838,7 +838,7 @@ where } }; - let mut message_cursor = Cursor::new(&self.client_message_buffer); + let mut message_cursor = Cursor::new(&self.buffer); message_cursor.advance(message_start); // The message will be forwarded to the server intact. We still would like to @@ -897,7 +897,7 @@ where debug!("Sending query to server"); let first_message_code = - (*self.client_message_buffer.get(0).unwrap_or(&0)) as char; + (*self.buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' { @@ -906,7 +906,7 @@ where // So message code should be in offset 0 of the buffer, first character // in prepared statement name would be index 5 let first_char_in_name = - *self.client_message_buffer.get(5).unwrap_or(&0); + *self.buffer.get(5).unwrap_or(&0); if first_char_in_name != 0 { // This is a named prepared statement // Server connection state will need to be cleared at checkin @@ -944,7 +944,7 @@ where self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, &server.server_message_buffer).await { + match write_all_half(&mut self.write, &server.buffer).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -952,7 +952,7 @@ where } }; - server.server_message_buffer.clear(); + server.buffer.clear(); if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -1008,7 +1008,7 @@ where loop { self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, &server.server_message_buffer).await { + match write_all_half(&mut self.write, &server.buffer).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1016,7 +1016,7 @@ where } }; - server.server_message_buffer.clear(); + server.buffer.clear(); if !server.is_data_available() { break; @@ -1039,14 +1039,14 @@ where address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { - match server.send(&self.client_message_buffer).await { + match server.send(&self.buffer).await { Ok(_) => { - self.client_message_buffer.clear(); + self.buffer.clear(); Ok(()) } Err(err) => { pool.ban(address, self.process_id); - self.client_message_buffer.clear(); + self.buffer.clear(); Err(err) } } diff --git a/src/server.rs b/src/server.rs index fc07ea09..e253c3d7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -33,7 +33,7 @@ pub struct Server { write: OwnedWriteHalf, /// Our server response buffer. We buffer data before we give it to the client. - pub server_message_buffer: BytesMut, + pub buffer: BytesMut, /// Server information the server sent us over on startup. server_info: BytesMut, @@ -316,7 +316,7 @@ impl Server { address: address.clone(), read: BufReader::new(read), write: write, - server_message_buffer: BytesMut::with_capacity(8196), + buffer: BytesMut::with_capacity(8196), server_info: server_info, server_id: server_id, process_id: process_id, @@ -398,7 +398,7 @@ impl Server { pub async fn recv(&mut self) -> Result<(), Error> { loop { let message_start = - match read_message(&mut self.read, &mut self.server_message_buffer).await { + match read_message(&mut self.read, &mut self.buffer).await { Ok(message) => message, Err(err) => { error!("Terminating server because of: {:?}", err); @@ -407,7 +407,7 @@ impl Server { } }; - let mut message_cursor = Cursor::new(&self.server_message_buffer); + let mut message_cursor = Cursor::new(&self.buffer); message_cursor.advance(message_start); let code = message_cursor.get_u8() as char; @@ -488,7 +488,7 @@ impl Server { self.data_available = true; // Don't flush yet, the more we buffer, the faster this goes...up to a limit. - if self.server_message_buffer.len() >= 8196 { + if self.buffer.len() >= 8196 { break; } } @@ -518,7 +518,7 @@ impl Server { // Keep track of how much data we got from the server for stats. self.stats - .data_received(self.server_message_buffer.len(), self.server_id); + .data_received(self.buffer.len(), self.server_id); // Successfully received data from server self.last_activity = SystemTime::now(); @@ -586,7 +586,7 @@ impl Server { } } - self.server_message_buffer.clear(); + self.buffer.clear(); Ok(()) } From ee19dbcad01d4c70e2325c5722e1ef6e9e6f4d5f Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 16:50:12 -0500 Subject: [PATCH 15/35] Fix bug --- src/query_router.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/query_router.rs b/src/query_router.rs index ff5466c7..ee6c67ad 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -279,7 +279,8 @@ impl QueryRouter { // Parse (prepared statement) 'P' => { - let mut start = 0; + // Start after the code and len + let mut start = std::mem::size_of::() + std::mem::size_of::(); // Skip the name of the prepared statement. while buf[start] != 0 && start < buf.len() { From 027011d0749eba8c58a4544c009722437abf6d2b Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 16:50:35 -0500 Subject: [PATCH 16/35] Rename read_messages to read_messages_into_buffer --- src/client.rs | 6 +++--- src/messages.rs | 2 +- src/server.rs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/client.rs b/src/client.rs index 26a182a8..5b13cf0b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -619,10 +619,10 @@ where // Admin clients ignore shutdown. else { - read_message(&mut self.read, &mut self.buffer).await? + read_message_into_buffer(&mut self.read, &mut self.buffer).await? } }, - message_result = read_message(&mut self.read, &mut self.buffer) => message_result? + message_result = read_message_into_buffer(&mut self.read, &mut self.buffer) => message_result? }; let mut message_cursor = Cursor::new(&self.buffer); @@ -821,7 +821,7 @@ where None => { trace!("Waiting for message inside transaction or in session mode"); - match read_message(&mut self.read, &mut self.buffer).await { + match read_message_into_buffer(&mut self.read, &mut self.buffer).await { Ok(message) => message, Err(err) => { // Client disconnected inside a transaction. diff --git a/src/messages.rs b/src/messages.rs index 18cc65d9..de25cb28 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -466,7 +466,7 @@ where } /// Read a complete message from the socket. -pub async fn read_message(stream: &mut S, buffer: &mut BytesMut) -> Result +pub async fn read_message_into_buffer(stream: &mut S, buffer: &mut BytesMut) -> Result where S: tokio::io::AsyncRead + std::marker::Unpin, { diff --git a/src/server.rs b/src/server.rs index e253c3d7..6295adc2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -398,7 +398,7 @@ impl Server { pub async fn recv(&mut self) -> Result<(), Error> { loop { let message_start = - match read_message(&mut self.read, &mut self.buffer).await { + match read_message_into_buffer(&mut self.read, &mut self.buffer).await { Ok(message) => message, Err(err) => { error!("Terminating server because of: {:?}", err); From f0ceb8b144cb724d1dcc9268969fa80072e86fdf Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 16:54:55 -0500 Subject: [PATCH 17/35] fmt --- src/client.rs | 6 ++---- src/lib.rs | 6 ++---- src/messages.rs | 5 ++++- src/server.rs | 3 +-- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/client.rs b/src/client.rs index 5b13cf0b..2c6e54e2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -896,8 +896,7 @@ where 'S' => { debug!("Sending query to server"); - let first_message_code = - (*self.buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' { @@ -905,8 +904,7 @@ where // P followed by 32 int followed by null-terminated statement name // So message code should be in offset 0 of the buffer, first character // in prepared statement name would be index 5 - let first_char_in_name = - *self.buffer.get(5).unwrap_or(&0); + let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); if first_char_in_name != 0 { // This is a named prepared statement // Server connection state will need to be cleared at checkin diff --git a/src/lib.rs b/src/lib.rs index 32a5f737..5ac54818 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,10 +44,8 @@ impl BytesMutReader for Cursor<&BytesMut> { fn read_string(&mut self) -> Result { let mut buf = vec![]; match self.read_until(b'\0', &mut buf) { - Ok(_) => {} + Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), Err(err) => return Err(Error::ParseBytesError(err.to_string())), - }; - - Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()) + } } } diff --git a/src/messages.rs b/src/messages.rs index de25cb28..12e2dd4f 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -466,7 +466,10 @@ where } /// Read a complete message from the socket. -pub async fn read_message_into_buffer(stream: &mut S, buffer: &mut BytesMut) -> Result +pub async fn read_message_into_buffer( + stream: &mut S, + buffer: &mut BytesMut, +) -> Result where S: tokio::io::AsyncRead + std::marker::Unpin, { diff --git a/src/server.rs b/src/server.rs index 6295adc2..7d866105 100644 --- a/src/server.rs +++ b/src/server.rs @@ -517,8 +517,7 @@ impl Server { } // Keep track of how much data we got from the server for stats. - self.stats - .data_received(self.buffer.len(), self.server_id); + self.stats.data_received(self.buffer.len(), self.server_id); // Successfully received data from server self.last_activity = SystemTime::now(); From 021a1088339e962ca8d31de923f95e106a352224 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 17:17:10 -0500 Subject: [PATCH 18/35] Rename to message_buffer to be more explicit --- src/client.rs | 47 +++++++++++++++++++++++++---------------------- src/server.rs | 14 +++++++------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/src/client.rs b/src/client.rs index 2c6e54e2..e72d8597 100644 --- a/src/client.rs +++ b/src/client.rs @@ -40,7 +40,7 @@ pub struct Client { /// Internal buffer, where we place messages until we have to flush /// them to the backend. - buffer: BytesMut, + message_buffer: BytesMut, /// Address addr: std::net::SocketAddr, @@ -492,7 +492,7 @@ where read: BufReader::new(read), write, addr, - buffer: BytesMut::with_capacity(8196), + message_buffer: BytesMut::with_capacity(8196), cancel_mode: false, transaction_mode, process_id, @@ -525,7 +525,7 @@ where Ok(Client { read: BufReader::new(read), write, - buffer: BytesMut::with_capacity(8196), + message_buffer: BytesMut::with_capacity(8196), addr, cancel_mode: true, transaction_mode: false, @@ -596,7 +596,7 @@ where ); if clear_buffer { - self.buffer.clear(); + self.message_buffer.clear(); } clear_buffer = true; @@ -619,13 +619,13 @@ where // Admin clients ignore shutdown. else { - read_message_into_buffer(&mut self.read, &mut self.buffer).await? + read_message_into_buffer(&mut self.read, &mut self.message_buffer).await? } }, - message_result = read_message_into_buffer(&mut self.read, &mut self.buffer) => message_result? + message_result = read_message_into_buffer(&mut self.read, &mut self.message_buffer) => message_result? }; - let mut message_cursor = Cursor::new(&self.buffer); + let mut message_cursor = Cursor::new(&self.message_buffer); message_cursor.advance(message_start); let message_code = message_cursor.get_u8(); @@ -654,7 +654,7 @@ where debug!("Handling admin command"); handle_admin( &mut self.write, - self.buffer.clone(), + self.message_buffer.clone(), self.client_server_map.clone(), ) .await?; @@ -684,11 +684,11 @@ where let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. - match query_router.try_execute_command(&self.buffer) { + match query_router.try_execute_command(&self.message_buffer) { // Normal query, not a custom command. None => { if query_router.query_parser_enabled() { - query_router.infer(&self.buffer); + query_router.infer(&self.message_buffer); } } @@ -770,7 +770,7 @@ where // protocol buffer if message_code as char == 'S' { error!("Got Sync message but failed to get a connection from the pool"); - self.buffer.clear(); + self.message_buffer.clear(); } error_response(&mut self.write, "could not get connection from the pool") .await?; @@ -821,7 +821,9 @@ where None => { trace!("Waiting for message inside transaction or in session mode"); - match read_message_into_buffer(&mut self.read, &mut self.buffer).await { + match read_message_into_buffer(&mut self.read, &mut self.message_buffer) + .await + { Ok(message) => message, Err(err) => { // Client disconnected inside a transaction. @@ -838,7 +840,7 @@ where } }; - let mut message_cursor = Cursor::new(&self.buffer); + let mut message_cursor = Cursor::new(&self.message_buffer); message_cursor.advance(message_start); // The message will be forwarded to the server intact. We still would like to @@ -896,7 +898,8 @@ where 'S' => { debug!("Sending query to server"); - let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = + (*self.message_buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' { @@ -904,7 +907,7 @@ where // P followed by 32 int followed by null-terminated statement name // So message code should be in offset 0 of the buffer, first character // in prepared statement name would be index 5 - let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); + let first_char_in_name = *self.message_buffer.get(5).unwrap_or(&0); if first_char_in_name != 0 { // This is a named prepared statement // Server connection state will need to be cleared at checkin @@ -942,7 +945,7 @@ where self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, &server.buffer).await { + match write_all_half(&mut self.write, &server.message_buffer).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -950,7 +953,7 @@ where } }; - server.buffer.clear(); + server.message_buffer.clear(); if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -1006,7 +1009,7 @@ where loop { self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, &server.buffer).await { + match write_all_half(&mut self.write, &server.message_buffer).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1014,7 +1017,7 @@ where } }; - server.buffer.clear(); + server.message_buffer.clear(); if !server.is_data_available() { break; @@ -1037,14 +1040,14 @@ where address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { - match server.send(&self.buffer).await { + match server.send(&self.message_buffer).await { Ok(_) => { - self.buffer.clear(); + self.message_buffer.clear(); Ok(()) } Err(err) => { pool.ban(address, self.process_id); - self.buffer.clear(); + self.message_buffer.clear(); Err(err) } } diff --git a/src/server.rs b/src/server.rs index 7d866105..b45e2adf 100644 --- a/src/server.rs +++ b/src/server.rs @@ -33,7 +33,7 @@ pub struct Server { write: OwnedWriteHalf, /// Our server response buffer. We buffer data before we give it to the client. - pub buffer: BytesMut, + pub message_buffer: BytesMut, /// Server information the server sent us over on startup. server_info: BytesMut, @@ -316,7 +316,7 @@ impl Server { address: address.clone(), read: BufReader::new(read), write: write, - buffer: BytesMut::with_capacity(8196), + message_buffer: BytesMut::with_capacity(8196), server_info: server_info, server_id: server_id, process_id: process_id, @@ -398,7 +398,7 @@ impl Server { pub async fn recv(&mut self) -> Result<(), Error> { loop { let message_start = - match read_message_into_buffer(&mut self.read, &mut self.buffer).await { + match read_message_into_buffer(&mut self.read, &mut self.message_buffer).await { Ok(message) => message, Err(err) => { error!("Terminating server because of: {:?}", err); @@ -407,7 +407,7 @@ impl Server { } }; - let mut message_cursor = Cursor::new(&self.buffer); + let mut message_cursor = Cursor::new(&self.message_buffer); message_cursor.advance(message_start); let code = message_cursor.get_u8() as char; @@ -488,7 +488,7 @@ impl Server { self.data_available = true; // Don't flush yet, the more we buffer, the faster this goes...up to a limit. - if self.buffer.len() >= 8196 { + if self.message_buffer.len() >= 8196 { break; } } @@ -517,7 +517,7 @@ impl Server { } // Keep track of how much data we got from the server for stats. - self.stats.data_received(self.buffer.len(), self.server_id); + self.stats.data_received(self.message_buffer.len(), self.server_id); // Successfully received data from server self.last_activity = SystemTime::now(); @@ -585,7 +585,7 @@ impl Server { } } - self.buffer.clear(); + self.message_buffer.clear(); Ok(()) } From 24d36af054cab996fe2c43ef7c06f89b8e007920 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 18:01:22 -0500 Subject: [PATCH 19/35] Read message body directly onto buffer --- src/messages.rs | 19 ++++++++++++------- src/server.rs | 3 ++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/messages.rs b/src/messages.rs index 12e2dd4f..f2e6ae42 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -483,19 +483,24 @@ where let len = match stream.read_i32().await { Ok(len) => len, Err(_) => return Err(Error::SocketError), - }; + } as usize; + + buffer.put_u8(code); + buffer.put_i32(len.try_into().unwrap()); - let mut buf = vec![0u8; len as usize - 4]; + buffer.resize(buffer.len() + len - 4, b'0'); - match stream.read_exact(&mut buf).await { + match stream + .read( + &mut buffer[starting_point + mem::size_of::() + mem::size_of::() + ..starting_point + mem::size_of::() + mem::size_of::() + len - 4], + ) + .await + { Ok(_) => (), Err(_) => return Err(Error::SocketError), }; - buffer.put_u8(code); - buffer.put_i32(len); - buffer.put_slice(&buf); - Ok(starting_point) } diff --git a/src/server.rs b/src/server.rs index b45e2adf..06f53ca4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -517,7 +517,8 @@ impl Server { } // Keep track of how much data we got from the server for stats. - self.stats.data_received(self.message_buffer.len(), self.server_id); + self.stats + .data_received(self.message_buffer.len(), self.server_id); // Successfully received data from server self.last_activity = SystemTime::now(); From 8a58f8b40b9bfa8b7cf0e01f4d244488ded8b3f4 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 18:23:32 -0500 Subject: [PATCH 20/35] Fix bug --- src/messages.rs | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/messages.rs b/src/messages.rs index f2e6ae42..a7b857dd 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -483,23 +483,26 @@ where let len = match stream.read_i32().await { Ok(len) => len, Err(_) => return Err(Error::SocketError), - } as usize; + }; buffer.put_u8(code); buffer.put_i32(len.try_into().unwrap()); - buffer.resize(buffer.len() + len - 4, b'0'); - - match stream - .read( - &mut buffer[starting_point + mem::size_of::() + mem::size_of::() - ..starting_point + mem::size_of::() + mem::size_of::() + len - 4], - ) - .await - { - Ok(_) => (), - Err(_) => return Err(Error::SocketError), - }; + buffer.resize(buffer.len() + len as usize - 4, b'0'); + + // Reading onto buffer will stall when [start..end] is 0 + if len - 4 != 0 { + match stream + .read( + &mut buffer[starting_point + mem::size_of::() + mem::size_of::() + ..starting_point + mem::size_of::() + mem::size_of::() + len as usize - 4], + ) + .await + { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + } Ok(starting_point) } From 089348bd728b3f6fd1027c8185e6530d81020a57 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 18:28:35 -0500 Subject: [PATCH 21/35] fmt --- src/messages.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/messages.rs b/src/messages.rs index a7b857dd..742cc74b 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -495,7 +495,8 @@ where match stream .read( &mut buffer[starting_point + mem::size_of::() + mem::size_of::() - ..starting_point + mem::size_of::() + mem::size_of::() + len as usize - 4], + ..starting_point + mem::size_of::() + mem::size_of::() + len as usize + - 4], ) .await { From ecb9627df48e10ebeadd7730e8d9ebb2ea5e1df9 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 18:34:44 -0500 Subject: [PATCH 22/35] Unused try_into --- src/messages.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/messages.rs b/src/messages.rs index 742cc74b..35350b3a 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -486,7 +486,7 @@ where }; buffer.put_u8(code); - buffer.put_i32(len.try_into().unwrap()); + buffer.put_i32(len); buffer.resize(buffer.len() + len as usize - 4, b'0'); From a3e3801c8e3e58edba620e53de88fead665d1d02 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 18:46:12 -0500 Subject: [PATCH 23/35] use read_exact instead of read to guarantee number of bytes read --- src/messages.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/messages.rs b/src/messages.rs index 35350b3a..8b531225 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -493,7 +493,7 @@ where // Reading onto buffer will stall when [start..end] is 0 if len - 4 != 0 { match stream - .read( + .read_exact( &mut buffer[starting_point + mem::size_of::() + mem::size_of::() ..starting_point + mem::size_of::() + mem::size_of::() + len as usize - 4], From bece3b90028b03060ef6a3f939e5b443c304d746 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 19:20:32 -0500 Subject: [PATCH 24/35] Remove length check --- src/messages.rs | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/messages.rs b/src/messages.rs index 8b531225..9bd6d4ec 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -490,20 +490,17 @@ where buffer.resize(buffer.len() + len as usize - 4, b'0'); - // Reading onto buffer will stall when [start..end] is 0 - if len - 4 != 0 { - match stream - .read_exact( - &mut buffer[starting_point + mem::size_of::() + mem::size_of::() - ..starting_point + mem::size_of::() + mem::size_of::() + len as usize - - 4], - ) - .await - { - Ok(_) => (), - Err(_) => return Err(Error::SocketError), - }; - } + match stream + .read_exact( + &mut buffer[starting_point + mem::size_of::() + mem::size_of::() + ..starting_point + mem::size_of::() + mem::size_of::() + len as usize + - 4], + ) + .await + { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; Ok(starting_point) } From 54ed58efe070b2106c254939428d81eef3dfff8e Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 16 Nov 2022 19:23:04 -0500 Subject: [PATCH 25/35] fmt --- src/messages.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/messages.rs b/src/messages.rs index 9bd6d4ec..9a51846a 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -493,8 +493,7 @@ where match stream .read_exact( &mut buffer[starting_point + mem::size_of::() + mem::size_of::() - ..starting_point + mem::size_of::() + mem::size_of::() + len as usize - - 4], + ..starting_point + mem::size_of::() + mem::size_of::() + len as usize - 4], ) .await { From ccde8f016e420c50270281698f1fc9c69c456d00 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 21 Nov 2022 11:34:12 -0500 Subject: [PATCH 26/35] Add comment to read_string function and uses sizeof for byte lengths in read_message --- src/lib.rs | 3 +++ src/messages.rs | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5ac54818..a8e98fe1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,9 @@ pub trait BytesMutReader { } impl BytesMutReader for Cursor<&BytesMut> { + /// Should only be used when reading strings from the message protocol. + /// + /// Can be used to read multiple strings from the same message which are separated by the null byte fn read_string(&mut self) -> Result { let mut buf = vec![]; match self.read_until(b'\0', &mut buf) { diff --git a/src/messages.rs b/src/messages.rs index cc2056d3..578a7545 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -501,12 +501,13 @@ where buffer.put_u8(code); buffer.put_i32(len); - buffer.resize(buffer.len() + len as usize - 4, b'0'); + buffer.resize(buffer.len() + len as usize - mem::size_of::(), b'0'); match stream .read_exact( &mut buffer[starting_point + mem::size_of::() + mem::size_of::() - ..starting_point + mem::size_of::() + mem::size_of::() + len as usize - 4], + ..starting_point + mem::size_of::() + mem::size_of::() + len as usize + - mem::size_of::()], ) .await { From 6c5b5df4be906af9a56ae739cc3da4a6ac5f3e33 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 21 Nov 2022 11:34:21 -0500 Subject: [PATCH 27/35] fmt --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index a8e98fe1..43d52fba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,7 +42,7 @@ pub trait BytesMutReader { impl BytesMutReader for Cursor<&BytesMut> { /// Should only be used when reading strings from the message protocol. - /// + /// /// Can be used to read multiple strings from the same message which are separated by the null byte fn read_string(&mut self) -> Result { let mut buf = vec![]; From 7bf950362cd8a28ded7dfabc23ef16215b34caf8 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 21 Nov 2022 16:13:56 -0500 Subject: [PATCH 28/35] Refactor reading query in infer role --- src/query_router.rs | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/query_router.rs b/src/query_router.rs index 9d9db5ed..be9f219f 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -279,22 +279,11 @@ impl QueryRouter { // Parse (prepared statement) 'P' => { - // Start after the code and len - let mut start = std::mem::size_of::() + std::mem::size_of::(); + // Reads statement name + message_cursor.read_string().unwrap(); - // Skip the name of the prepared statement. - while buf[start] != 0 && start < buf.len() { - start += 1; - } - start += 1; // Skip terminating null - - // Find the end of the prepared stmt (\0) - let mut end = start; - while buf[end] != 0 && end < buf.len() { - end += 1; - } - - let query = String::from_utf8_lossy(&buf[start..end]).to_string(); + // Reads query string + let query = message_cursor.read_string().unwrap(); debug!("Prepared statement: '{}'", query); From 5f8473cfdc612fe6a9587dd84b2002ea254f95c0 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 21 Nov 2022 21:09:15 -0500 Subject: [PATCH 29/35] Creates send server message to client function to both send to client and clear server buffer Fixes incorrect log variable bug --- src/admin.rs | 3 +-- src/client.rs | 52 +++++++++++++++++++++------------------------ src/lib.rs | 25 +++------------------- src/messages.rs | 18 ++++++++++++++++ src/query_router.rs | 3 +-- src/server.rs | 30 +++++++++++++++++++++++++- 6 files changed, 76 insertions(+), 55 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index 5879114a..4ad3b14f 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -7,11 +7,10 @@ use tokio::time::Instant; use crate::config::{get_config, reload_config, VERSION}; use crate::errors::Error; use crate::messages::*; -use crate::pool::get_all_pools; +use crate::pool::{get_all_pools, ClientServerMap}; use crate::stats::{ get_address_stats, get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState, }; -use crate::ClientServerMap; pub fn generate_server_info_for_admin() -> BytesMut { let mut server_info = BytesMut::new(); diff --git a/src/client.rs b/src/client.rs index 4923441b..b7ec6cff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -36,7 +36,7 @@ pub struct Client { /// We buffer the writes ourselves because we know the protocol /// better than a stock buffer. - write: T, + pub write: T, /// Internal buffer, where we place messages until we have to flush /// them to the backend. @@ -433,7 +433,7 @@ where let code = match read.read_u8().await { Ok(p) => p, - Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), + Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), }; // PasswordMessage @@ -446,14 +446,14 @@ where let len = match read.read_i32().await { Ok(len) => len, - Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), + Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), }; let mut password_response = vec![0u8; (len - 4) as usize]; match read.read_exact(&mut password_response).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), + Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), }; // Authenticate admin user. @@ -467,10 +467,10 @@ where ); if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); + warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); wrong_password(&mut write, username).await?; - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); + return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); } (false, generate_server_info_for_admin()) @@ -489,7 +489,7 @@ where ) .await?; - return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); + return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); } }; @@ -497,10 +497,10 @@ where let password_hash = md5_hash_password(username, &pool.settings.user.password, &salt); if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); + warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); wrong_password(&mut write, username).await?; - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); + return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); } let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; @@ -705,7 +705,7 @@ where ) .await?; - return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name))); + return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.username, self.pool_name, self.application_name))); } }; query_router.update_pool_settings(pool.settings.clone()); @@ -973,15 +973,7 @@ where self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, &server.message_buffer).await { - Ok(_) => (), - Err(err) => { - server.mark_bad(); - return Err(err); - } - }; - - server.message_buffer.clear(); + server.send_buffered_messages_to_client(self).await?; if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -1037,15 +1029,7 @@ where loop { self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, &server.message_buffer).await { - Ok(_) => (), - Err(err) => { - server.mark_bad(); - return Err(err); - } - }; - - server.message_buffer.clear(); + server.send_buffered_messages_to_client(self).await?; if !server.is_data_available() { break; @@ -1132,6 +1116,18 @@ where } } } + + pub fn get_username(&self) -> &String { + &self.username + } + + pub fn get_pool_name(&self) -> &String { + &self.pool_name + } + + pub fn get_application_name(&self) -> &String { + &self.application_name + } } impl Drop for Client { diff --git a/src/lib.rs b/src/lib.rs index 43d52fba..8f6c5def 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,11 @@ -use std::io::{BufRead, Cursor}; - -use bytes::BytesMut; -use errors::Error; - +pub mod admin; +pub mod client; pub mod config; pub mod constants; pub mod errors; pub mod messages; pub mod pool; +pub mod query_router; pub mod scram; pub mod server; pub mod sharding; @@ -35,20 +33,3 @@ pub fn format_duration(duration: &chrono::Duration) -> String { days, hours, minutes, seconds, milliseconds ) } - -pub trait BytesMutReader { - fn read_string(&mut self) -> Result; -} - -impl BytesMutReader for Cursor<&BytesMut> { - /// Should only be used when reading strings from the message protocol. - /// - /// Can be used to read multiple strings from the same message which are separated by the null byte - fn read_string(&mut self) -> Result { - let mut buf = vec![]; - match self.read_until(b'\0', &mut buf) { - Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), - Err(err) => return Err(Error::ParseBytesError(err.to_string())), - } - } -} diff --git a/src/messages.rs b/src/messages.rs index 578a7545..c34ffb2b 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -7,6 +7,7 @@ use tokio::net::TcpStream; use crate::errors::Error; use std::collections::HashMap; +use std::io::{BufRead, Cursor}; use std::mem; /// Postgres data type mappings @@ -539,3 +540,20 @@ pub fn server_parameter_message(key: &str, value: &str) -> BytesMut { server_info } + +pub trait BytesMutReader { + fn read_string(&mut self) -> Result; +} + +impl BytesMutReader for Cursor<&BytesMut> { + /// Should only be used when reading strings from the message protocol. + /// + /// Can be used to read multiple strings from the same message which are separated by the null byte + fn read_string(&mut self) -> Result { + let mut buf = vec![]; + match self.read_until(b'\0', &mut buf) { + Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), + Err(err) => return Err(Error::ParseBytesError(err.to_string())), + } + } +} diff --git a/src/query_router.rs b/src/query_router.rs index be9f219f..69aec9be 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -10,14 +10,13 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use crate::config::Role; +use crate::messages::BytesMutReader; use crate::pool::PoolSettings; use crate::sharding::Sharder; use std::collections::BTreeSet; use std::io::Cursor; -use pgcat::BytesMutReader; - /// Regexes used to parse custom commands. const CUSTOM_SQL_REGEXES: [&str; 7] = [ r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$", diff --git a/src/server.rs b/src/server.rs index e853293e..a5cded9d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,6 +10,7 @@ use tokio::net::{ TcpStream, }; +use crate::client::Client; use crate::config::{Address, User}; use crate::constants::*; use crate::errors::Error; @@ -33,7 +34,7 @@ pub struct Server { write: OwnedWriteHalf, /// Our server response buffer. We buffer data before we give it to the client. - pub message_buffer: BytesMut, + message_buffer: BytesMut, /// Server information the server sent us over on startup. server_info: BytesMut, @@ -535,6 +536,33 @@ impl Server { Ok(()) } + /// This function takes a client and sends the buffered server messages to the client. + /// This will also clear the the server message buffer + pub async fn send_buffered_messages_to_client( + &mut self, + client: &mut Client, + ) -> Result<(), Error> + where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, + { + if self.message_buffer.is_empty() { + return Err(Error::ClientError(format!("Attempting to send empty buffered server message to client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", client.get_username(), client.get_pool_name(), client.get_application_name()))); + } + + match write_all_half(&mut client.write, &self.message_buffer).await { + Ok(_) => {} + Err(_) => { + self.mark_bad(); // We have some weird state with the server and buffer here + return Err(Error::SocketError(format!("Error sending server message to client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", client.get_username(), client.get_pool_name(), client.get_application_name()))); + } + }; + + self.message_buffer.clear(); + + Ok(()) + } + /// If the server is still inside a transaction. /// If the client disconnects while the server is in a transaction, we will clean it up. pub fn in_transaction(&self) -> bool { From 382e0b831e0194dc7586e3ec3c97e770d0edd3dc Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 28 Nov 2022 15:30:26 -0500 Subject: [PATCH 30/35] Ensures buffer is cleared before it checked in --- src/server.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/server.rs b/src/server.rs index a5cded9d..e86444ca 100644 --- a/src/server.rs +++ b/src/server.rs @@ -635,6 +635,12 @@ impl Server { // Pgbouncer behavior is to close the server connection but that can cause // server connection thrashing if clients repeatedly do this. // Instead, we ROLLBACK that transaction before putting the connection back in the pool + + if !self.message_buffer.is_empty() { + warn!("Server message buffer was not cleated before cleanup"); + self.message_buffer.clear(); + } + if self.in_transaction() { warn!("Server returned while still in transaction, rolling back transaction"); self.query("ROLLBACK").await?; From 459e5d474b302c94ff2a8234f26a561d8bd6aebb Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 28 Nov 2022 16:33:32 -0500 Subject: [PATCH 31/35] revert? --- src/server.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/server.rs b/src/server.rs index e86444ca..6e487a78 100644 --- a/src/server.rs +++ b/src/server.rs @@ -636,10 +636,10 @@ impl Server { // server connection thrashing if clients repeatedly do this. // Instead, we ROLLBACK that transaction before putting the connection back in the pool - if !self.message_buffer.is_empty() { - warn!("Server message buffer was not cleated before cleanup"); - self.message_buffer.clear(); - } + // if !self.message_buffer.is_empty() { + // warn!("Server message buffer was not cleated before cleanup"); + // self.message_buffer.clear(); + // } if self.in_transaction() { warn!("Server returned while still in transaction, rolling back transaction"); From ae4078ead9c6180c046c7a10a052765171e1eefa Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Mon, 28 Nov 2022 20:54:23 -0500 Subject: [PATCH 32/35] try again --- src/server.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/server.rs b/src/server.rs index 6e487a78..e86444ca 100644 --- a/src/server.rs +++ b/src/server.rs @@ -636,10 +636,10 @@ impl Server { // server connection thrashing if clients repeatedly do this. // Instead, we ROLLBACK that transaction before putting the connection back in the pool - // if !self.message_buffer.is_empty() { - // warn!("Server message buffer was not cleated before cleanup"); - // self.message_buffer.clear(); - // } + if !self.message_buffer.is_empty() { + warn!("Server message buffer was not cleated before cleanup"); + self.message_buffer.clear(); + } if self.in_transaction() { warn!("Server returned while still in transaction, rolling back transaction"); From 2a84549286e8c56ce98cdaebfc9c458dcaf62be2 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 21 Dec 2022 15:41:56 -0500 Subject: [PATCH 33/35] Update comments --- src/server.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/server.rs b/src/server.rs index eedae4f2..d4806aaa 100644 --- a/src/server.rs +++ b/src/server.rs @@ -635,16 +635,17 @@ impl Server { /// Perform any necessary cleanup before putting the server /// connection back in the pool pub async fn checkin_cleanup(&mut self) -> Result<(), Error> { - // Client disconnected with an open transaction on the server connection. - // Pgbouncer behavior is to close the server connection but that can cause - // server connection thrashing if clients repeatedly do this. - // Instead, we ROLLBACK that transaction before putting the connection back in the pool + // Incase the message buffer wasn't flushed properly if !self.message_buffer.is_empty() { warn!("Server message buffer was not cleated before cleanup"); self.message_buffer.clear(); } + // Client disconnected with an open transaction on the server connection. + // Pgbouncer behavior is to close the server connection but that can cause + // server connection thrashing if clients repeatedly do this. + // Instead, we ROLLBACK that transaction before putting the connection back in the pool if self.in_transaction() { warn!("Server returned while still in transaction, rolling back transaction"); self.query("ROLLBACK").await?; From fb891fc1619f6eb4f19c5c4133f715ba8b5e7230 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 21 Dec 2022 15:45:09 -0500 Subject: [PATCH 34/35] fmt --- src/server.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/server.rs b/src/server.rs index d4806aaa..0c8127a1 100644 --- a/src/server.rs +++ b/src/server.rs @@ -635,7 +635,6 @@ impl Server { /// Perform any necessary cleanup before putting the server /// connection back in the pool pub async fn checkin_cleanup(&mut self) -> Result<(), Error> { - // Incase the message buffer wasn't flushed properly if !self.message_buffer.is_empty() { warn!("Server message buffer was not cleated before cleanup"); From 9937aba5148e53b069505e33482209db90131847 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 21 Dec 2022 15:46:16 -0500 Subject: [PATCH 35/35] fix --- src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server.rs b/src/server.rs index 0c8127a1..826bc607 100644 --- a/src/server.rs +++ b/src/server.rs @@ -515,7 +515,7 @@ impl Server { // CopyData 'd' => { // Don't flush yet, buffer until we reach limit - if self.buffer.len() >= 8196 { + if self.message_buffer.len() >= 8196 { break; } }