diff --git a/src/admin.rs b/src/admin.rs index 4460f982..4ad3b14f 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -7,11 +7,10 @@ use tokio::time::Instant; use crate::config::{get_config, reload_config, VERSION}; use crate::errors::Error; use crate::messages::*; -use crate::pool::get_all_pools; +use crate::pool::{get_all_pools, ClientServerMap}; use crate::stats::{ get_address_stats, get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState, }; -use crate::ClientServerMap; pub fn generate_server_info_for_admin() -> BytesMut { let mut server_info = BytesMut::new(); @@ -171,7 +170,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show PgCat version. @@ -189,7 +188,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show utilization of connection pools for each shard and replicas. @@ -250,7 +249,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show shards and replicas. @@ -317,7 +316,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Ignore any SET commands the client sends. @@ -349,7 +348,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Shows current configuration. @@ -395,7 +394,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show shard and replicas statistics. 
@@ -455,7 +454,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show currently connected clients @@ -505,7 +504,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show currently connected servers @@ -559,5 +558,5 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } diff --git a/src/client.rs b/src/client.rs index b55906b2..b7ec6cff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,6 +2,7 @@ use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; use std::collections::HashMap; +use std::io::Cursor; use std::time::Instant; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; @@ -35,11 +36,11 @@ pub struct Client { /// We buffer the writes ourselves because we know the protocol /// better than a stock buffer. - write: T, + pub write: T, /// Internal buffer, where we place messages until we have to flush /// them to the backend. - buffer: BytesMut, + message_buffer: BytesMut, /// Address addr: std::net::SocketAddr, @@ -380,7 +381,7 @@ where admin_only: bool, ) -> Result, Error> { let stats = get_reporter(); - let parameters = parse_startup(bytes.clone())?; + let parameters = parse_startup(bytes)?; // This parameter is mandatory by the protocol. 
let username = match parameters.get("user") { @@ -432,7 +433,7 @@ where let code = match read.read_u8().await { Ok(p) => p, - Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), + Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), }; // PasswordMessage @@ -445,14 +446,14 @@ where let len = match read.read_i32().await { Ok(len) => len, - Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), + Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), }; let mut password_response = vec![0u8; (len - 4) as usize]; match read.read_exact(&mut password_response).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), + Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), }; // Authenticate admin user. 
@@ -466,10 +467,10 @@ where ); if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); + warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); wrong_password(&mut write, username).await?; - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); + return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); } (false, generate_server_info_for_admin()) @@ -488,7 +489,7 @@ where ) .await?; - return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); + return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); } }; @@ -496,10 +497,10 @@ where let password_hash = md5_hash_password(username, &pool.settings.user.password, &salt); if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); + warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); wrong_password(&mut write, username).await?; - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); + return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); } let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; @@ -520,7 +521,7 
@@ where read: BufReader::new(read), write, addr, - buffer: BytesMut::with_capacity(8196), + message_buffer: BytesMut::with_capacity(8196), cancel_mode: false, transaction_mode, process_id, @@ -553,8 +554,8 @@ where Ok(Client { read: BufReader::new(read), write, + message_buffer: BytesMut::with_capacity(8196), addr, - buffer: BytesMut::with_capacity(8196), cancel_mode: true, transaction_mode: false, process_id, @@ -612,6 +613,8 @@ where self.application_name.clone(), ); + let mut clear_buffer = false; + // Our custom protocol loop. // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocol. @@ -621,13 +624,19 @@ where self.transaction_mode ); + if clear_buffer { + self.message_buffer.clear(); + } + + clear_buffer = true; + // Read a complete message from the client, which normally would be // either a `Q` (query) or `P` (prepare, extended protocol). // We can parse it here before grabbing a server from the pool, // in case the client is sending some custom protocol messages, e.g. // SET SHARDING KEY TO 'bigint'; - let message = tokio::select! { + let message_start = tokio::select! { _ = self.shutdown.recv() => { if !self.admin { error_response_terminal( @@ -639,13 +648,18 @@ where // Admin clients ignore shutdown. else { - read_message(&mut self.read).await? + read_message_into_buffer(&mut self.read, &mut self.message_buffer).await? } }, - message_result = read_message(&mut self.read) => message_result? + message_result = read_message_into_buffer(&mut self.read, &mut self.message_buffer) => message_result? }; - match message[0] as char { + let mut message_cursor = Cursor::new(&self.message_buffer); + message_cursor.advance(message_start); + + let message_code = message_cursor.get_u8(); + + match message_code as char { // Buffer extended protocol messages even if we do not have // a server connection yet. Hopefully, when we get the S message // we'll be able to allocate a connection. 
Also, clients do not expect @@ -654,7 +668,7 @@ where // to the client so we buffer them and defer the decision to error out or not // to when we get the S message 'P' | 'B' | 'D' | 'E' => { - self.buffer.put(&message[..]); + clear_buffer = false; continue; } 'X' => { @@ -667,7 +681,12 @@ where // Handle admin database queries. if self.admin { debug!("Handling admin command"); - handle_admin(&mut self.write, message, self.client_server_map.clone()).await?; + handle_admin( + &mut self.write, + self.message_buffer.clone(), + self.client_server_map.clone(), + ) + .await?; continue; } @@ -686,18 +705,18 @@ where ) .await?; - return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name))); + return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.username, self.pool_name, self.application_name))); } }; query_router.update_pool_settings(pool.settings.clone()); let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. - match query_router.try_execute_command(message.clone()) { + match query_router.try_execute_command(&self.message_buffer) { // Normal query, not a custom command. 
None => { if query_router.query_parser_enabled() { - query_router.infer(message.clone()); + query_router.infer(&self.message_buffer); } } @@ -777,15 +796,15 @@ where // but we were unable to grab a connection from the pool // We'll send back an error message and clean the extended // protocol buffer - if message[0] as char == 'S' { + if message_code as char == 'S' { error!("Got Sync message but failed to get a connection from the pool"); - self.buffer.clear(); + self.message_buffer.clear(); } error_response(&mut self.write, "could not get connection from the pool") .await?; error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}", - self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err); + self.pool_name, self.username, query_router.shard(), query_router.role(), err); continue; } }; @@ -817,7 +836,7 @@ where // Set application_name. server.set_name(&self.application_name).await?; - let mut initial_message = Some(message); + let mut initial_message_start = Some(message_start); // Transaction loop. Multiple queries can be issued by the client here. // The connection belongs to the client until the transaction is over, @@ -826,11 +845,13 @@ where // If the client is in session mode, no more custom protocol // commands will be accepted. loop { - let message = match initial_message { + let message_start = match initial_message_start { None => { trace!("Waiting for message inside transaction or in session mode"); - match read_message(&mut self.read).await { + match read_message_into_buffer(&mut self.read, &mut self.message_buffer) + .await + { Ok(message) => message, Err(err) => { // Client disconnected inside a transaction. 
@@ -841,18 +862,18 @@ where } } } - Some(message) => { - initial_message = None; - message + Some(message_start) => { + initial_message_start = None; + message_start } }; + let mut message_cursor = Cursor::new(&self.message_buffer); + message_cursor.advance(message_start); + // The message will be forwarded to the server intact. We still would like to // parse it below to figure out what to do with it. - - // Safe to unwrap because we know this message has a certain length and has the code - // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.get(0).unwrap() as char; + let code = message_cursor.get_u8() as char; trace!("Message: {}", code); @@ -861,7 +882,7 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop(code, message, server, &address, &pool) + self.send_and_receive_loop(code, server, &address, &pool) .await?; if !server.in_transaction() { @@ -886,36 +907,27 @@ where // Parse // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. - 'P' => { - self.buffer.put(&message[..]); - } + 'P' => {} // Bind // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' - 'B' => { - self.buffer.put(&message[..]); - } + 'B' => {} // Describe // Command a client can issue to describe a previously prepared named statement. - 'D' => { - self.buffer.put(&message[..]); - } + 'D' => {} // Execute // Execute a prepared statement prepared in `P` and bound in `B`. - 'E' => { - self.buffer.put(&message[..]); - } + 'E' => {} // Sync // Frontend (client) is asking for the query result now. 
'S' => { debug!("Sending query to server"); - self.buffer.put(&message[..]); - - let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = + (*self.message_buffer.get(0).unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' { @@ -923,7 +935,7 @@ where // P followed by 32 int followed by null-terminated statement name // So message code should be in offset 0 of the buffer, first character // in prepared statement name would be index 5 - let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); + let first_char_in_name = *self.message_buffer.get(5).unwrap_or(&0); if first_char_in_name != 0 { // This is a named prepared statement // Server connection state will need to be cleared at checkin @@ -931,16 +943,8 @@ where } } - self.send_and_receive_loop( - code, - self.buffer.clone(), - server, - &address, - &pool, - ) - .await?; - - self.buffer.clear(); + self.send_and_receive_loop(code, server, &address, &pool) + .await?; if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -957,25 +961,19 @@ where 'd' => { // Forward the data to the server, // don't buffer it since it can be rather large. - self.send_server_message(server, message, &address, &pool) + self.send_client_message_to_server(server, &address, &pool) .await?; } // CopyDone or CopyFail // Copy is done, successfully or not. 
'c' | 'f' => { - self.send_server_message(server, message, &address, &pool) + self.send_client_message_to_server(server, &address, &pool) .await?; - let response = self.receive_server_message(server, &address, &pool).await?; + self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, response).await { - Ok(_) => (), - Err(err) => { - server.mark_bad(); - return Err(err); - } - }; + server.send_buffered_messages_to_client(self).await?; if !server.in_transaction() { self.stats.transaction(self.process_id, server.server_id()); @@ -1016,29 +1014,22 @@ where async fn send_and_receive_loop( &mut self, code: char, - message: BytesMut, server: &mut Server, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { debug!("Sending {} to server", code); - self.send_server_message(server, message, address, pool) + self.send_client_message_to_server(server, &address, &pool) .await?; let query_start = Instant::now(); // Read all data the server has to offer, which can be multiple messages // buffered in 8196 bytes chunks. 
loop { - let response = self.receive_server_message(server, address, pool).await?; + self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, response).await { - Ok(_) => (), - Err(err) => { - server.mark_bad(); - return Err(err); - } - }; + server.send_buffered_messages_to_client(self).await?; if !server.is_data_available() { break; @@ -1055,17 +1046,20 @@ where Ok(()) } - async fn send_server_message( - &self, + async fn send_client_message_to_server( + &mut self, server: &mut Server, - message: BytesMut, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { - match server.send(message).await { - Ok(_) => Ok(()), + match server.send(&self.message_buffer).await { + Ok(_) => { + self.message_buffer.clear(); + Ok(()) + } Err(err) => { pool.ban(address, self.process_id); + self.message_buffer.clear(); Err(err) } } @@ -1076,7 +1070,7 @@ where server: &mut Server, address: &Address, pool: &ConnectionPool, - ) -> Result { + ) -> Result<(), Error> { if pool.settings.user.statement_timeout > 0 { match tokio::time::timeout( tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), @@ -1085,7 +1079,7 @@ where .await { Ok(result) => match result { - Ok(message) => Ok(message), + Ok(_) => Ok(()), Err(err) => { pool.ban(address, self.process_id); error_response_terminal( @@ -1109,7 +1103,7 @@ where } } else { match server.recv().await { - Ok(message) => Ok(message), + Ok(_) => Ok(()), Err(err) => { pool.ban(address, self.process_id); error_response_terminal( @@ -1122,6 +1116,18 @@ where } } } + + pub fn get_username(&self) -> &String { + &self.username + } + + pub fn get_pool_name(&self) -> &String { + &self.pool_name + } + + pub fn get_application_name(&self) -> &String { + &self.application_name + } } impl Drop for Client { diff --git a/src/errors.rs b/src/errors.rs index 7789a8a7..4ac23a85 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -13,4 +13,5 @@ pub enum Error { TlsError, StatementTimeout, 
ShuttingDown, + ParseBytesError(String), } diff --git a/src/lib.rs b/src/lib.rs index e9a683f3..8f6c5def 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,11 @@ +pub mod admin; +pub mod client; pub mod config; pub mod constants; pub mod errors; pub mod messages; pub mod pool; +pub mod query_router; pub mod scram; pub mod server; pub mod sharding; diff --git a/src/messages.rs b/src/messages.rs index 826508ee..c34ffb2b 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -7,6 +7,7 @@ use tokio::net::TcpStream; use crate::errors::Error; use std::collections::HashMap; +use std::io::{BufRead, Cursor}; use std::mem; /// Postgres data type mappings @@ -258,7 +259,7 @@ where res.put_i32(len); res.put_slice(&set_complete[..]); - write_all_half(stream, res).await?; + write_all_half(stream, &res).await?; ready_for_query(stream).await } @@ -308,7 +309,7 @@ where res.put_i32(error.len() as i32 + 4); res.put(error); - write_all_half(stream, res).await + write_all_half(stream, &res).await } pub async fn wrong_password(stream: &mut S, user: &str) -> Result<(), Error> @@ -370,7 +371,7 @@ where // CommandComplete res.put(command_complete("SELECT 1")); - write_all_half(stream, res).await?; + write_all_half(stream, &res).await?; ready_for_query(stream).await } @@ -459,21 +460,26 @@ where } /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). -pub async fn write_all_half(stream: &mut S, buf: BytesMut) -> Result<(), Error> +pub async fn write_all_half(stream: &mut S, buf: &BytesMut) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { - match stream.write_all(&buf).await { + match stream.write_all(buf).await { Ok(_) => Ok(()), Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))), } } /// Read a complete message from the socket. 
-pub async fn read_message(stream: &mut S) -> Result +pub async fn read_message_into_buffer( + stream: &mut S, + buffer: &mut BytesMut, +) -> Result where S: tokio::io::AsyncRead + std::marker::Unpin, { + let starting_point = buffer.len(); + let code = match stream.read_u8().await { Ok(code) => code, Err(_) => { @@ -493,9 +499,19 @@ where } }; - let mut buf = vec![0u8; len as usize - 4]; + buffer.put_u8(code); + buffer.put_i32(len); - match stream.read_exact(&mut buf).await { + buffer.resize(buffer.len() + len as usize - mem::size_of::(), b'0'); + + match stream + .read_exact( + &mut buffer[starting_point + mem::size_of::() + mem::size_of::() + ..starting_point + mem::size_of::() + mem::size_of::() + len as usize + - mem::size_of::()], + ) + .await + { Ok(_) => (), Err(_) => { return Err(Error::SocketError(format!( @@ -505,13 +521,7 @@ where } }; - let mut bytes = BytesMut::with_capacity(len as usize + 1); - - bytes.put_u8(code); - bytes.put_i32(len); - bytes.put_slice(&buf); - - Ok(bytes) + Ok(starting_point) } pub fn server_parameter_message(key: &str, value: &str) -> BytesMut { @@ -530,3 +540,20 @@ pub fn server_parameter_message(key: &str, value: &str) -> BytesMut { server_info } + +pub trait BytesMutReader { + fn read_string(&mut self) -> Result; +} + +impl BytesMutReader for Cursor<&BytesMut> { + /// Should only be used when reading strings from the message protocol. 
+ /// + /// Can be used to read multiple strings from the same message which are separated by the null byte + fn read_string(&mut self) -> Result { + let mut buf = vec![]; + match self.read_until(b'\0', &mut buf) { + Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), + Err(err) => return Err(Error::ParseBytesError(err.to_string())), + } + } +} diff --git a/src/query_router.rs b/src/query_router.rs index 50905716..69aec9be 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -10,10 +10,12 @@ use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use crate::config::Role; +use crate::messages::BytesMutReader; use crate::pool::PoolSettings; use crate::sharding::Sharder; use std::collections::BTreeSet; +use std::io::Cursor; /// Regexes used to parse custom commands. const CUSTOM_SQL_REGEXES: [&str; 7] = [ @@ -107,16 +109,18 @@ impl QueryRouter { } /// Try to parse a command and execute it. - pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> { - let code = buf.get_u8() as char; + pub fn try_execute_command(&mut self, buf: &BytesMut) -> Option<(Command, String)> { + let mut message_cursor = Cursor::new(buf); + + let code = message_cursor.get_u8() as char; // Only simple protocol supported for commands. if code != 'Q' { return None; } - let len = buf.get_i32() as usize; - let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); // Ignore the terminating NULL. + let _len = message_cursor.get_i32() as usize; + let query = message_cursor.read_string().unwrap(); let regex_set = match CUSTOM_SQL_REGEX_SET.get() { Some(regex_set) => regex_set, @@ -256,37 +260,29 @@ impl QueryRouter { } /// Try to infer which server to connect to based on the contents of the query. 
- pub fn infer(&mut self, mut buf: BytesMut) -> bool { + pub fn infer(&mut self, buf: &BytesMut) -> bool { debug!("Inferring role"); - let code = buf.get_u8() as char; - let len = buf.get_i32() as usize; + let mut message_cursor = Cursor::new(buf); + + let code = message_cursor.get_u8() as char; + let _len = message_cursor.get_i32() as usize; let query = match code { // Query 'Q' => { - let query = String::from_utf8_lossy(&buf[..len - 5]).to_string(); + let query = message_cursor.read_string().unwrap(); debug!("Query: '{}'", query); query } // Parse (prepared statement) 'P' => { - let mut start = 0; - - // Skip the name of the prepared statement. - while buf[start] != 0 && start < buf.len() { - start += 1; - } - start += 1; // Skip terminating null - - // Find the end of the prepared stmt (\0) - let mut end = start; - while buf[end] != 0 && end < buf.len() { - end += 1; - } + // Reads statement name + message_cursor.read_string().unwrap(); - let query = String::from_utf8_lossy(&buf[start..end]).to_string(); + // Reads query string + let query = message_cursor.read_string().unwrap(); debug!("Prepared statement: '{}'", query); @@ -519,10 +515,10 @@ mod test { fn test_infer_replica() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); + assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr.query_parser_enabled()); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); let queries = vec![ simple_query("SELECT * FROM items WHERE id = 5"), @@ -534,7 +530,7 @@ mod test { for query in queries { // It's a recognized query - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Replica)); } } @@ -553,7 +549,7 @@ mod test { for query in queries { // It's a recognized query - 
assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Primary)); } } @@ -563,9 +559,9 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), None); } @@ -573,8 +569,8 @@ mod test { fn test_infer_parse_prepared() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); let prepared_stmt = BytesMut::from( &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], @@ -585,7 +581,7 @@ mod test { res.put(prepared_stmt); res.put_i16(0); - assert!(qr.infer(res)); + assert!(qr.infer(&res)); assert_eq!(qr.role(), Some(Role::Replica)); } @@ -668,7 +664,7 @@ mod test { // SetShardingKey let query = simple_query("SET SHARDING KEY TO 13"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetShardingKey, String::from("0"))) ); assert_eq!(qr.shard(), 0); @@ -676,7 +672,7 @@ mod test { // SetShard let query = simple_query("SET SHARD TO '1'"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetShard, String::from("1"))) ); assert_eq!(qr.shard(), 1); @@ -684,7 +680,7 @@ mod test { // ShowShard let query = simple_query("SHOW SHARD"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::ShowShard, String::from("1"))) ); @@ -702,7 +698,7 @@ mod test { for (idx, role) in 
roles.iter().enumerate() { let query = simple_query(&format!("SET SERVER ROLE TO '{}'", role)); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::SetServerRole, String::from(*role))) ); assert_eq!(qr.role(), verify_roles[idx],); @@ -711,7 +707,7 @@ mod test { // ShowServerRole let query = simple_query("SHOW SERVER ROLE"); assert_eq!( - qr.try_execute_command(query), + qr.try_execute_command(&query), Some((Command::ShowServerRole, String::from(*role))) ); } @@ -721,14 +717,14 @@ mod test { for (idx, primary_reads) in primary_reads.iter().enumerate() { assert_eq!( - qr.try_execute_command(simple_query(&format!( + qr.try_execute_command(&simple_query(&format!( "SET PRIMARY READS TO {}", primary_reads ))), Some((Command::SetPrimaryReads, String::from(*primary_reads))) ); assert_eq!( - qr.try_execute_command(simple_query("SHOW PRIMARY READS")), + qr.try_execute_command(&simple_query("SHOW PRIMARY READS")), Some(( Command::ShowPrimaryReads, String::from(primary_reads_enabled[idx]) @@ -742,23 +738,23 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SET SERVER ROLE TO 'auto'"); - assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); - assert!(qr.try_execute_command(query) != None); + assert!(qr.try_execute_command(&query) != None); assert!(qr.query_parser_enabled()); assert_eq!(qr.role(), None); let query = simple_query("INSERT INTO test_table VALUES (1)"); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Primary)); let query = simple_query("SELECT * FROM test_table"); - assert!(qr.infer(query)); + assert!(qr.infer(&query)); assert_eq!(qr.role(), Some(Role::Replica)); assert!(qr.query_parser_enabled()); let query = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(query) != None); + 
assert!(qr.try_execute_command(&query) != None); assert!(!qr.query_parser_enabled()); } @@ -793,16 +789,16 @@ mod test { assert!(!qr.primary_reads_enabled()); let q1 = simple_query("SET SERVER ROLE TO 'primary'"); - assert!(qr.try_execute_command(q1) != None); + assert!(qr.try_execute_command(&q1) != None); assert_eq!(qr.active_role.unwrap(), Role::Primary); let q2 = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(q2) != None); + assert!(qr.try_execute_command(&q2) != None); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); // Here we go :) let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)"); - assert!(qr.infer(q3)); + assert!(qr.infer(&q3)); assert_eq!(qr.shard(), 1); } @@ -811,13 +807,13 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.infer(simple_query("BEGIN; SELECT 1; COMMIT;"))); + assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;"))); assert_eq!(qr.role(), Role::Primary); - assert!(qr.infer(simple_query("SELECT 1; SELECT 2;"))); + assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;"))); assert_eq!(qr.role(), Role::Replica); - assert!(qr.infer(simple_query( + assert!(qr.infer(&simple_query( "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" ))); assert_eq!(qr.role(), Role::Primary); diff --git a/src/server.rs b/src/server.rs index 05a3b770..826bc607 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,7 +2,7 @@ /// Here we are pretending to the a Postgres client. use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; -use std::io::Read; +use std::io::{Cursor, Read}; use std::time::SystemTime; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ @@ -10,6 +10,7 @@ use tokio::net::{ TcpStream, }; +use crate::client::Client; use crate::config::{Address, User}; use crate::constants::*; use crate::errors::Error; @@ -33,7 +34,7 @@ pub struct Server { write: OwnedWriteHalf, /// Our server response buffer. 
We buffer data before we give it to the client. - buffer: BytesMut, + message_buffer: BytesMut, /// Server information the server sent us over on startup. server_info: BytesMut, @@ -318,12 +319,12 @@ impl Server { let mut server = Server { address: address.clone(), read: BufReader::new(read), - write, - buffer: BytesMut::with_capacity(8196), - server_info, - server_id, - process_id, - secret_key, + write: write, + message_buffer: BytesMut::with_capacity(8196), + server_info: server_info, + server_id: server_id, + process_id: process_id, + secret_key: secret_key, in_transaction: false, data_available: false, bad: false, @@ -381,7 +382,7 @@ impl Server { } /// Send messages to the server from the client. - pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> { + pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> { self.stats.data_sent(messages.len(), self.server_id); match write_all_half(&mut self.write, messages).await { @@ -401,29 +402,30 @@ impl Server { /// Receive data from the server in response to a client request. /// This method must be called multiple times while `self.is_data_available()` is true /// in order to receive all data the server has to offer. - pub async fn recv(&mut self) -> Result { + pub async fn recv(&mut self) -> Result<(), Error> { loop { - let mut message = match read_message(&mut self.read).await { - Ok(message) => message, - Err(err) => { - error!("Terminating server because of: {:?}", err); - self.bad = true; - return Err(err); - } - }; + let message_start = + match read_message_into_buffer(&mut self.read, &mut self.message_buffer).await { + Ok(message) => message, + Err(err) => { + error!("Terminating server because of: {:?}", err); + self.bad = true; + return Err(err); + } + }; - // Buffer the message we'll forward to the client later. 
- self.buffer.put(&message[..]); + let mut message_cursor = Cursor::new(&self.message_buffer); + message_cursor.advance(message_start); - let code = message.get_u8() as char; - let _len = message.get_i32(); + let code = message_cursor.get_u8() as char; + let _len = message_cursor.get_i32(); trace!("Message: {}", code); match code { // ReadyForQuery 'Z' => { - let transaction_state = message.get_u8() as char; + let transaction_state = message_cursor.get_u8() as char; match transaction_state { // In transaction. @@ -460,7 +462,7 @@ impl Server { // CommandComplete 'C' => { let mut command_tag = String::new(); - match message.reader().read_to_string(&mut command_tag) { + match message_cursor.reader().read_to_string(&mut command_tag) { Ok(_) => { // Non-exhaustive list of commands that are likely to change session variables/resources // which can leak between clients. This is a best effort to block bad clients @@ -496,7 +498,7 @@ impl Server { self.data_available = true; // Don't flush yet, the more we buffer, the faster this goes...up to a limit. - if self.buffer.len() >= 8196 { + if self.message_buffer.len() >= 8196 { break; } } @@ -513,7 +515,7 @@ impl Server { // CopyData 'd' => { // Don't flush yet, buffer until we reach limit - if self.buffer.len() >= 8196 { + if self.message_buffer.len() >= 8196 { break; } } @@ -528,19 +530,41 @@ impl Server { }; } - let bytes = self.buffer.clone(); - // Keep track of how much data we got from the server for stats. - self.stats.data_received(bytes.len(), self.server_id); - - // Clear the buffer for next query. - self.buffer.clear(); + self.stats + .data_received(self.message_buffer.len(), self.server_id); // Successfully received data from server self.last_activity = SystemTime::now(); - // Pass the data back to the client. - Ok(bytes) + Ok(()) + } + + /// This function takes a client and sends the buffered server messages to the client. 
+ /// This will also clear the server message buffer + pub async fn send_buffered_messages_to_client<S, T>( + &mut self, + client: &mut Client<S, T>, + ) -> Result<(), Error> + where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, + { + if self.message_buffer.is_empty() { + return Err(Error::ClientError(format!("Attempting to send empty buffered server message to client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", client.get_username(), client.get_pool_name(), client.get_application_name()))); + } + + match write_all_half(&mut client.write, &self.message_buffer).await { + Ok(_) => {} + Err(_) => { + self.mark_bad(); // We have some weird state with the server and buffer here + return Err(Error::SocketError(format!("Error sending server message to client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", client.get_username(), client.get_pool_name(), client.get_application_name()))); + } + }; + + self.message_buffer.clear(); + + Ok(()) } /// If the server is still inside a transaction. @@ -593,7 +617,7 @@ impl Server { pub async fn query(&mut self, query: &str) -> Result<(), Error> { let query = simple_query(query); - self.send(query).await?; + self.send(&query).await?; loop { let _ = self.recv().await?; @@ -603,12 +627,20 @@ impl Server { } } + self.message_buffer.clear(); + Ok(()) } /// Perform any necessary cleanup before putting the server /// connection back in the pool pub async fn checkin_cleanup(&mut self) -> Result<(), Error> { + // In case the message buffer wasn't flushed properly + if !self.message_buffer.is_empty() { + warn!("Server message buffer was not cleated before cleanup"); + self.message_buffer.clear(); + } + // Client disconnected with an open transaction on the server connection. // Pgbouncer behavior is to close the server connection but that can cause // server connection thrashing if clients repeatedly do this.