Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ admin_username = "admin_user"
# Password to access the virtual administrative database
admin_password = "admin_pass"

# Enable/disable client TLS
client_tls = false

# Default plugins that are configured on all pools.
[plugins]

Expand Down
15 changes: 15 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub struct Client<S, T> {
}

/// Client entrypoint.
#[allow(clippy::too_many_arguments)]
pub async fn client_entrypoint(
mut stream: TcpStream,
client_server_map: ClientServerMap,
Expand All @@ -122,6 +123,7 @@ pub async fn client_entrypoint(
admin_only: bool,
tls_certificate: Option<String>,
log_client_connections: bool,
client_tls: bool,
) -> Result<(), Error> {
// Figure out if the client wants TLS or not.
let addr = match stream.peer_addr() {
Expand All @@ -134,6 +136,13 @@ pub async fn client_entrypoint(
}
};

if client_tls && tls_certificate.is_none() {
error!("Client tls is required but no certificate passed");
return Err(Error::ClientSSLError(
"Client tls is required but no certificate passed".into(),
));
}

match get_startup::<TcpStream>(&mut stream).await {
// Client requested a TLS connection.
Ok((ClientConnectionType::Tls, _)) => {
Expand Down Expand Up @@ -239,6 +248,12 @@ pub async fn client_entrypoint(

// Client wants to use plain connection without encryption.
Ok((ClientConnectionType::Startup, bytes)) => {
// Check if Client TLS is compulsory
if client_tls {
error!("TLS is required for client connections.");
return Err(Error::TlsError);
}

let (read, write) = split(stream);

// Continue with regular startup.
Expand Down
4 changes: 4 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ pub struct General {
#[serde(default)] // false
pub verify_server_certificate: bool,

#[serde(default)] //false
pub client_tls: bool,

pub admin_username: String,
pub admin_password: String,

Expand Down Expand Up @@ -460,6 +463,7 @@ impl Default for General {
auth_query: None,
auth_query_user: None,
auth_query_password: None,
client_tls: false,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub enum Error {
SocketError(String),
ClientSocketError(String, ClientIdentifier),
ClientGeneralError(String, ClientIdentifier),
ClientSSLError(String),
ClientAuthImpossible(String),
ClientAuthPassthroughError(String, ClientIdentifier),
ClientBadStartup,
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
admin_only,
tls_certificate,
config.general.log_client_connections,
config.general.client_tls,
)
.await
{
Expand Down
2 changes: 1 addition & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/// Implementation of the PostgreSQL server (database) protocol.
/// Here we are pretending to the a Postgres client.
/// Here we are pretending to be a client to the Postgres server.
use bytes::{Buf, BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn};
Expand Down
164 changes: 164 additions & 0 deletions tests/go/pgcat_ssl.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#
# PgCat config example.
#

#
# General pooler settings
[general]
# What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0"

# Port to run on, same as PgBouncer used in this example.
port = "${PORT}"

# Whether to enable prometheus exporter or not.
enable_prometheus_exporter = true

# Port at which prometheus exporter listens on.
prometheus_exporter_port = 9930

# How long to wait before aborting a server connection (ms).
connect_timeout = 1000

# How much time to give the health check query to return with a result (ms).
healthcheck_timeout = 1000

# How long to keep connection available for immediate re-use, without running a healthcheck query on it
healthcheck_delay = 30000

# How much time to give clients during shutdown before forcibly killing client connections (ms).
shutdown_timeout = 5000

# For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds

# If we should log client connections
log_client_connections = false

# If we should log client disconnections
log_client_disconnections = false

# Reload config automatically if it changes.
autoreload = 15000

server_round_robin = false

# TLS
tls_certificate = "../../.circleci/server.cert"
tls_private_key = "../../.circleci/server.key"

# Credentials to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
admin_username = "admin_user"
admin_password = "admin_pass"

client_tls = true

# pool
# configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting
# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db"
[pools.sharded_db]
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"

# If the client doesn't specify, route traffic to
# this role by default.
#
# any: round-robin between primary and replicas,
# replica: round-robin between replicas only without touching the primary,
# primary: all queries go to the primary unless otherwise specified.
default_role = "any"

# Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary.
query_parser_enabled = true

# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true

# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitly selected with our custom protocol.
primary_reads_enabled = true

# So what if you wanted to implement a different hashing function,
# or you've already built one and you want this pooler to use it?
#
# Current options:
#
# pg_bigint_hash: PARTITION BY HASH (Postgres hashing function)
# sha1: A hashing function based on SHA1
#
sharding_function = "pg_bigint_hash"

# Prepared statements cache size.
prepared_statements_cache_size = 500

# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connections from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 5
statement_timeout = 0


[pools.sharded_db.users.1]
username = "other_user"
password = "other_user"
pool_size = 21
statement_timeout = 30000

# Shard 0
[pools.sharded_db.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"

[pools.sharded_db.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"

[pools.sharded_db.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"


[pools.simple_db]
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
query_parser_read_write_splitting = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"

[pools.simple_db.users.0]
username = "simple_user"
password = "simple_user"
pool_size = 5
statement_timeout = 30000

[pools.simple_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
database = "some_db"
5 changes: 3 additions & 2 deletions tests/go/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"testing"

_ "github.com/lib/pq"
)

func Test(t *testing.T) {
t.Cleanup(setup(t))
t.Cleanup(setupNonTls(t))
t.Run("Named parameterized prepared statement works", namedParameterizedPreparedStatement)
t.Run("Unnamed parameterized prepared statement works", unnamedParameterizedPreparedStatement)
}
Expand Down
67 changes: 66 additions & 1 deletion tests/go/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,74 @@ import (
//go:embed pgcat.toml
var pgcatCfg string

//go:embed pgcat_ssl.toml
var pgcatTlsCfg string

var port = rand.Intn(32760-20000) + 20000
var ssl_port = port + 1

// setupTls writes a TLS-enabled pgcat config to a temp file, launches the
// pgcat binary against it, and polls (up to ~5s) until the admin database
// answers a query over TLS. It returns a teardown func that stops pgcat and
// removes the temp config.
func setupTls(t *testing.T) func() {
	cfg, err := os.CreateTemp("/tmp", "pgcat_ssl_cfg_*.toml")
	if err != nil {
		t.Fatalf("could not create temp file: %+v", err)
	}
	// Substitute the listen port into the embedded config template.
	pgcatTlsCfg = strings.Replace(pgcatTlsCfg, "\"${PORT}\"", fmt.Sprintf("%d", ssl_port), 1)

	if _, err = cfg.Write([]byte(pgcatTlsCfg)); err != nil {
		t.Fatalf("could not write temp file: %+v", err)
	}

	commandPath := "../../target/debug/pgcat"
	if dir := os.Getenv("CARGO_TARGET_DIR"); dir != "" {
		commandPath = dir + "/debug/pgcat"
	}

	cmd := exec.Command(commandPath, cfg.Name())
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	go func() {
		// Use a goroutine-local err: assigning the outer `err` here would
		// race with the polling loop below.
		if err := cmd.Run(); err != nil {
			t.Errorf("could not run pgcat: %+v", err)
		}
	}()

	deadline, cancelFunc := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
	defer cancelFunc()

	// Poll until pgcat responds or the deadline fires. A labeled break is
	// required: a bare `break` inside `select` only exits the select, not
	// the for loop, so the original unlabeled version never retried.
wait:
	for {
		select {
		case <-deadline.Done():
			break wait
		case <-time.After(50 * time.Millisecond):
			db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=pgcat user=admin_user password=admin_pass sslmode=require sslcert=../../.circleci/server.cert sslkey=../../.circleci/server.key", ssl_port))
			if err != nil {
				continue
			}
			rows, err := db.QueryContext(deadline, "SHOW STATS")
			if err != nil {
				// Close the handle before retrying to avoid leaking
				// connections across poll iterations.
				_ = db.Close()
				continue
			}
			_ = rows.Close()
			_ = db.Close()
			break wait
		}
	}

	return func() {
		if err := cmd.Process.Signal(os.Interrupt); err != nil {
			t.Fatalf("could not interrupt pgcat: %+v", err)
		}
		if err := os.Remove(cfg.Name()); err != nil {
			t.Fatalf("could not remove temp file: %+v", err)
		}
	}
}

func setup(t *testing.T) func() {
func setupNonTls(t *testing.T) func() {
cfg, err := os.CreateTemp("/tmp", "pgcat_cfg_*.toml")
if err != nil {
t.Fatalf("could not create temp file: %+v", err)
Expand Down
Loading