diff --git a/pgcat.toml b/pgcat.toml index 9e19c13b..ae8dc573 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -77,6 +77,9 @@ admin_username = "admin_user" # Password to access the virtual administrative database admin_password = "admin_pass" +# Enable/disable client TLS +client_tls = false + # Default plugins that are configured on all pools. [plugins] diff --git a/src/client.rs b/src/client.rs index 23392b73..43c2c86e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -114,6 +114,7 @@ pub struct Client { } /// Client entrypoint. +#[allow(clippy::too_many_arguments)] pub async fn client_entrypoint( mut stream: TcpStream, client_server_map: ClientServerMap, @@ -122,6 +123,7 @@ pub async fn client_entrypoint( admin_only: bool, tls_certificate: Option, log_client_connections: bool, + client_tls: bool, ) -> Result<(), Error> { // Figure out if the client wants TLS or not. let addr = match stream.peer_addr() { @@ -134,6 +136,13 @@ pub async fn client_entrypoint( } }; + if client_tls && tls_certificate.is_none() { + error!("Client tls is required but no certificate passed"); + return Err(Error::ClientSSLError( + "Client tls is required but no certificate passed".into(), + )); + } + match get_startup::(&mut stream).await { // Client requested a TLS connection. Ok((ClientConnectionType::Tls, _)) => { @@ -239,6 +248,12 @@ pub async fn client_entrypoint( // Client wants to use plain connection without encryption. Ok((ClientConnectionType::Startup, bytes)) => { + // Check if Client TLS is compulsory + if client_tls { + error!("TLS is required for client connections."); + return Err(Error::TlsError); + } + let (read, write) = split(stream); // Continue with regular startup. diff --git a/src/config.rs b/src/config.rs index ef7952f2..cc407ab6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -331,6 +331,9 @@ pub struct General { #[serde(default)] // false pub verify_server_certificate: bool, + #[serde(default)] //false + pub client_tls: bool, + pub admin_username: String, pub admin_password: String, @@ -460,6 +463,7 @@ impl Default for General { auth_query: None, auth_query_user: None, auth_query_password: None, + client_tls: false, } } } diff --git a/src/errors.rs b/src/errors.rs index 13047b4b..3022f984 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -6,6 +6,7 @@ pub enum Error { SocketError(String), ClientSocketError(String, ClientIdentifier), ClientGeneralError(String, ClientIdentifier), + ClientSSLError(String), ClientAuthImpossible(String), ClientAuthPassthroughError(String, ClientIdentifier), ClientBadStartup, diff --git a/src/main.rs b/src/main.rs index 6c8c1654..4394aa26 100644 --- a/src/main.rs +++ b/src/main.rs @@ -287,6 +287,7 @@ fn main() -> Result<(), Box> { admin_only, tls_certificate, config.general.log_client_connections, + config.general.client_tls, ) .await { diff --git a/src/server.rs b/src/server.rs index 882450ea..aebede1b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,5 @@ /// Implementation of the PostgreSQL server (database) protocol. -/// Here we are pretending to the a Postgres client. +/// Here we are pretending to be a client to the Postgres server. use bytes::{Buf, BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use log::{debug, error, info, trace, warn}; diff --git a/tests/go/pgcat_ssl.toml b/tests/go/pgcat_ssl.toml new file mode 100644 index 00000000..3b55cabe --- /dev/null +++ b/tests/go/pgcat_ssl.toml @@ -0,0 +1,164 @@ +# +# PgCat config example. +# + +# +# General pooler settings +[general] +# What IP to run on, 0.0.0.0 means accessible from everywhere. +host = "0.0.0.0" + +# Port to run on, same as PgBouncer used in this example. +port = "${PORT}" + +# Whether to enable prometheus exporter or not. +enable_prometheus_exporter = true + +# Port at which prometheus exporter listens on. +prometheus_exporter_port = 9930 + +# How long to wait before aborting a server connection (ms). +connect_timeout = 1000 + +# How much time to give the health check query to return with a result (ms). +healthcheck_timeout = 1000 + +# How long to keep connection available for immediate re-use, without running a healthcheck query on it +healthcheck_delay = 30000 + +# How much time to give clients during shutdown before forcibly killing client connections (ms). +shutdown_timeout = 5000 + +# For how long to ban a server if it fails a health check (seconds). +ban_time = 60 # Seconds + +# If we should log client connections +log_client_connections = false + +# If we should log client disconnections +log_client_disconnections = false + +# Reload config automatically if it changes. +autoreload = 15000 + +server_round_robin = false + +# TLS +tls_certificate = "../../.circleci/server.cert" +tls_private_key = "../../.circleci/server.key" + +# Credentials to access the virtual administrative database (pgbouncer or pgcat) +# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. +admin_username = "admin_user" +admin_password = "admin_pass" + +client_tls = true + +# pool +# configs are structured as pool. +# the pool_name is what clients use as database name when connecting +# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db" +[pools.sharded_db] +# Pool mode (see PgBouncer docs for more). +# session: one server connection per connected client +# transaction: one server connection per client transaction +pool_mode = "transaction" + +# If the client doesn't specify, route traffic to +# this role by default. +# +# any: round-robin between primary and replicas, +# replica: round-robin between replicas only without touching the primary, +# primary: all queries go to the primary unless otherwise specified. +default_role = "any" + +# Query parser. If enabled, we'll attempt to parse +# every incoming query to determine if it's a read or a write. +# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, +# we'll direct it to the primary. +query_parser_enabled = true + +# If the query parser is enabled and this setting is enabled, we'll attempt to +# infer the role from the query itself. +query_parser_read_write_splitting = true + +# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for +# load balancing of read queries. Otherwise, the primary will only be used for write +# queries. The primary can always be explicitely selected with our custom protocol. +primary_reads_enabled = true + +# So what if you wanted to implement a different hashing function, +# or you've already built one and you want this pooler to use it? +# +# Current options: +# +# pg_bigint_hash: PARTITION BY HASH (Postgres hashing function) +# sha1: A hashing function based on SHA1 +# +sharding_function = "pg_bigint_hash" + +# Prepared statements cache size. +prepared_statements_cache_size = 500 + +# Credentials for users that may connect to this cluster +[pools.sharded_db.users.0] +username = "sharding_user" +password = "sharding_user" +# Maximum number of server connections that can be established for this user +# The maximum number of connection from a single Pgcat process to any database in the cluster +# is the sum of pool_size across all users. +pool_size = 5 +statement_timeout = 0 + + +[pools.sharded_db.users.1] +username = "other_user" +password = "other_user" +pool_size = 21 +statement_timeout = 30000 + +# Shard 0 +[pools.sharded_db.shards.0] +# [ host, port, role ] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ] +] +# Database name (e.g. "postgres") +database = "shard0" + +[pools.sharded_db.shards.1] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], +] +database = "shard1" + +[pools.sharded_db.shards.2] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], +] +database = "shard2" + + +[pools.simple_db] +pool_mode = "session" +default_role = "primary" +query_parser_enabled = true +query_parser_read_write_splitting = true +primary_reads_enabled = true +sharding_function = "pg_bigint_hash" + +[pools.simple_db.users.0] +username = "simple_user" +password = "simple_user" +pool_size = 5 +statement_timeout = 30000 + +[pools.simple_db.shards.0] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ] +] +database = "some_db" \ No newline at end of file diff --git a/tests/go/prepared_test.go b/tests/go/prepared_test.go index 0a42e721..c8bd7085 100644 --- a/tests/go/prepared_test.go +++ b/tests/go/prepared_test.go @@ -4,12 +4,13 @@ import ( "context" "database/sql" "fmt" - _ "github.com/lib/pq" "testing" + + _ "github.com/lib/pq" ) func Test(t *testing.T) { - t.Cleanup(setup(t)) + t.Cleanup(setupNonTls(t)) t.Run("Named parameterized prepared statement works", namedParameterizedPreparedStatement) t.Run("Unnamed parameterized prepared statement works", unnamedParameterizedPreparedStatement) } diff --git a/tests/go/setup.go b/tests/go/setup.go index 32ffc4ba..ca29ae8b 100644 --- a/tests/go/setup.go +++ b/tests/go/setup.go @@ -16,9 +16,74 @@ import ( //go:embed pgcat.toml var pgcatCfg string +//go:embed pgcat_ssl.toml +var pgcatTlsCfg string + var port = rand.Intn(32760-20000) + 20000 +var ssl_port = port + 1 + +func setupTls(t *testing.T) func() { + cfg, err := os.CreateTemp("/tmp", "pgcat_ssl_cfg_*.toml") + if err != nil { + t.Fatalf("could not create temp file: %+v", err) + } + pgcatTlsCfg = strings.Replace(pgcatTlsCfg, "\"${PORT}\"", fmt.Sprintf("%d", ssl_port), 1) + + _, err = cfg.Write([]byte(pgcatTlsCfg)) + if err != nil { + t.Fatalf("could not write temp file: %+v", err) + } + + commandPath := "../../target/debug/pgcat" + if os.Getenv("CARGO_TARGET_DIR") != "" { + commandPath = os.Getenv("CARGO_TARGET_DIR") + "/debug/pgcat" + } + + cmd := exec.Command(commandPath, cfg.Name()) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + go func() { + err = cmd.Run() + if err != nil { + t.Errorf("could not run pgcat: %+v", err) + } + }() + + deadline, cancelFunc := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cancelFunc() + for { + select { + case <-deadline.Done(): + break + case <-time.After(50 * time.Millisecond): + db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=pgcat user=admin_user password=admin_pass sslmode=require sslcert=../../.circleci/server.cert sslkey=../../.circleci/server.key", ssl_port)) + if err != nil { + continue + } + rows, err := db.QueryContext(deadline, "SHOW STATS") + if err != nil { + continue + } + _ = rows.Close() + _ = db.Close() + break + } + break + } + + return func() { + err := cmd.Process.Signal(os.Interrupt) + if err != nil { + t.Fatalf("could not interrupt pgcat: %+v", err) + } + err = os.Remove(cfg.Name()) + if err != nil { + t.Fatalf("could not remove temp file: %+v", err) + } + } +} -func setup(t *testing.T) func() { +func setupNonTls(t *testing.T) func() { cfg, err := os.CreateTemp("/tmp", "pgcat_cfg_*.toml") if err != nil { t.Fatalf("could not create temp file: %+v", err) diff --git a/tests/go/ssl_test.go b/tests/go/ssl_test.go new file mode 100644 index 00000000..3f533b5f --- /dev/null +++ b/tests/go/ssl_test.go @@ -0,0 +1,51 @@ +package pgcat + +import ( + "database/sql" + "fmt" + "testing" + + _ "github.com/lib/pq" +) + +func TestSSL(t *testing.T) { + t.Cleanup((setupTls(t))) + t.Run("Prepared Statement and Query on SSL connection works", namedParameterizedPreparedStatementOnSSL) + t.Run("Connection without ssl params fails", connectionWithoutSSLParams) +} + +func namedParameterizedPreparedStatementOnSSL(t *testing.T) { + db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=require", ssl_port)) + if err != nil { + t.Fatalf("could not open connection: %+v", err) + } + + stmt, err := db.Prepare("SELECT $1") + + if err != nil { + t.Fatalf("could not prepare: %+v", err) + } + + for i := 0; i < 100; i++ { + rows, err := stmt.Query(1) + if err != nil { + t.Fatalf("could not query: %+v", err) + } + _ = rows.Close() + } + + defer db.Close() + +} + +func connectionWithoutSSLParams(t *testing.T) { + db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", ssl_port)) + if err != nil { + t.Fatalf("could not open connection: %+v", err) + } + err = db.Ping() + if err == nil { + t.Fatalf("Non TLS Client connection established on a server that requires TLS") + } + +}