Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions experimental/postgres/cmd/cancel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package postgrescmd

import (
"context"
"errors"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestWithStatementTimeout_ZeroIsPassthrough(t *testing.T) {
parent := t.Context()
got, cancel := withStatementTimeout(parent, 0)
defer cancel()
// Parent and got should compare equal: zero timeout returns the parent
// unchanged (and a no-op cancel).
deadline, ok := got.Deadline()
assert.False(t, ok, "deadline should not be set when timeout is 0")
assert.True(t, deadline.IsZero())
}

func TestWithStatementTimeout_AppliesDeadline(t *testing.T) {
parent := t.Context()
got, cancel := withStatementTimeout(parent, time.Second)
defer cancel()
deadline, ok := got.Deadline()
assert.True(t, ok)
assert.False(t, deadline.IsZero())
}

func TestReportCancellation_SignalCanceled(t *testing.T) {
signalCtx, signalCancel := context.WithCancel(t.Context())
signalCancel()
stmtCtx := signalCtx
msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("anything"), 0)
assert.Equal(t, "Query cancelled.", msg)
assert.True(t, invocationScoped)
}

func TestReportCancellation_TimeoutFired(t *testing.T) {
signalCtx := t.Context()
stmtCtx, stmtCancel := context.WithDeadline(signalCtx, time.Now().Add(-time.Second))
defer stmtCancel()
<-stmtCtx.Done()
msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("query failed"), 5*time.Second)
assert.Equal(t, "Query timed out after 5s.", msg)
assert.True(t, invocationScoped)
}

func TestReportCancellation_GenericError(t *testing.T) {
signalCtx := t.Context()
stmtCtx := signalCtx
msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("syntax error"), 0)
assert.Equal(t, "syntax error", msg)
assert.False(t, invocationScoped)
}

func TestReportCancellation_BothFire_CancelWinsRace(t *testing.T) {
// User cancel and deadline both already done. Precedence: cancel wins
// (the user's intent dominates a coincidental deadline). A future
// reordering of the switch would silently flip this; the test pins it.
signalCtx, signalCancel := context.WithCancel(t.Context())
signalCancel()
stmtCtx, stmtCancel := context.WithDeadline(signalCtx, time.Now().Add(-time.Second))
defer stmtCancel()
<-stmtCtx.Done()
msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("anything"), time.Second)
assert.Equal(t, "Query cancelled.", msg)
assert.True(t, invocationScoped)
}

func TestWatchInterruptSignals_CancelOnStop(t *testing.T) {
// stop should cancel the parent context as a side-effect so the goroutine
// terminates promptly. We don't actually send a SIGINT here (it would
// also kill the test runner); we just verify stop cleans up.
parent, parentCancel := context.WithCancel(t.Context())
defer parentCancel()

cancelled := false
cancel := func() {
cancelled = true
parentCancel()
}

stop := watchInterruptSignals(parent, cancel)
stop()
assert.True(t, cancelled, "stop should call cancel to wake the goroutine")
}
22 changes: 22 additions & 0 deletions experimental/postgres/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/databricks/cli/libs/log"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
)

// defaultConnectTimeout is the dial timeout for a single connect attempt.
Expand Down Expand Up @@ -52,6 +53,19 @@ type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, erro
// in the resolved values. The DSN-then-patch pattern is the recommended way
// to configure pgx for `sslmode=require` because building a pgx.ConnConfig
// by hand omits internal fields that the parser sets.
//
// The context-watcher handler is overridden so context cancellation issues
// a Postgres CancelRequest on the side-channel rather than only closing the
// underlying TCP connection. Without this override, a Ctrl+C during a long
// SELECT would tear down the TCP socket but leave the server-side query
// running until it noticed the broken connection on its next write.
//
// CancelRequestDelay = 0: send the cancel-request immediately on ctx cancel.
// The user just hit Ctrl+C; we want the server to learn now.
// DeadlineDelay = 5s: if the cancel-request has not gotten the server to
// terminate the query within 5s, fall back to deadlining the connection.
// Zero DeadlineDelay would race the cancel-request and could leave the
// connection unusable.
func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) {
cfg, err := pgx.ParseConfig("postgresql:///?sslmode=require")
if err != nil {
Expand All @@ -63,6 +77,14 @@ func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) {
cfg.Password = c.Password
cfg.Database = c.Database
cfg.ConnectTimeout = c.ConnectTimeout

cfg.BuildContextWatcherHandler = func(pgc *pgconn.PgConn) ctxwatch.Handler {
return &pgconn.CancelRequestContextWatcherHandler{
Conn: pgc,
CancelRequestDelay: 0,
DeadlineDelay: 5 * time.Second,
}
}
return cfg, nil
}

Expand Down
90 changes: 74 additions & 16 deletions experimental/postgres/cmd/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type queryFlags struct {
connectTimeout time.Duration
maxRetries int
files []string
timeout time.Duration

// outputFormat is the raw flag value. resolveOutputFormat turns it into
// the effective format (which may differ when stdout is piped).
Expand Down Expand Up @@ -96,6 +97,7 @@ Limitations (this release):
cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name")
cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout")
cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)")
cmd.Flags().DurationVar(&f.timeout, "timeout", 0, "Per-statement timeout (0 disables)")
cmd.Flags().StringArrayVarP(&f.files, "file", "f", nil, "SQL file path (repeatable)")
cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv")
cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
Expand Down Expand Up @@ -178,10 +180,21 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla
MaxDelay: 10 * time.Second,
}

conn, err := connectWithRetry(ctx, pgxCfg, rc, pgx.ConnectConfig)
// Invocation-scoped context: cancelled by Ctrl+C/SIGTERM. Owns the
// connection lifecycle. Per-statement timeouts are children of this so
// a cancelled invocation also cancels the in-flight statement.
signalCtx, signalCancel := context.WithCancel(ctx)
defer signalCancel()

stopSignals := watchInterruptSignals(signalCtx, signalCancel)
defer stopSignals()

conn, err := connectWithRetry(signalCtx, pgxCfg, rc, pgx.ConnectConfig)
if err != nil {
return err
}
// Close on a background ctx so a cancelled signalCtx does not abort a
// clean teardown handshake.
defer conn.Close(context.WithoutCancel(ctx))

out := cmd.OutOrStdout()
Expand All @@ -192,9 +205,16 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla
// Avoids buffering rows for large exports and matches the v1 single-
// input behaviour PR 2 shipped. Wrap the error so DETAIL / HINT
// from a *pgconn.PgError surface even on the single-input path.
sink := newSink(format, out, stderr)
if err := executeOne(ctx, conn, units[0].SQL, sink); err != nil {
return errors.New(formatPgError(err))
// Promote-to-interactive only when stdout is a prompt-capable TTY so
// a pipe falls back to the static table rather than launching a TUI
// into a dead writer.
sink := newSinkInteractive(format, out, stderr, stdoutTTY && cmdio.IsPromptSupported(ctx))
stmtCtx, stmtCancel := withStatementTimeout(signalCtx, f.timeout)
err := executeOne(stmtCtx, conn, units[0].SQL, sink)
stmtCancel()
if err != nil {
msg, _ := reportCancellation(signalCtx, stmtCtx, err, f.timeout)
return errors.New(msg)
}
return nil
}
Expand All @@ -205,7 +225,9 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla
// temp tables) carries across units because we hold the same connection.
results := make([]*unitResult, 0, len(units))
for _, u := range units {
r, err := runUnitBuffered(ctx, conn, u)
stmtCtx, stmtCancel := withStatementTimeout(signalCtx, f.timeout)
r, err := runUnitBuffered(stmtCtx, conn, u)
stmtCancel()
if err != nil {
// Render the successful prefix, then surface the error with
// rich pgError formatting if applicable.
Expand All @@ -214,7 +236,14 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla
// error to the user, the renderer error to debug logs.
fmt.Fprintln(stderr, "warning: failed to render partial result:", rerr)
}
return formatExecutionError(u.Source, err)
msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, err, f.timeout)
if invocationScoped {
// User cancel / timeout is invocation-scoped; the source
// prefix is redundant ("--file foo.sql: Query cancelled."
// reads worse than just "Query cancelled.").
return errors.New(msg)
}
return errors.New(u.Source + ": " + msg)
}
results = append(results, r)
}
Expand All @@ -227,15 +256,51 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla
}
}

// newSink returns the rowSink for the chosen output format. Kept separate
// from runQuery so tests can build sinks without going through pgx.
func newSink(format sqlcli.Format, out, stderr io.Writer) rowSink {
// withStatementTimeout returns ctx unchanged (and a no-op cancel) when
// timeout is zero, otherwise a child context with the timeout applied. Each
// statement gets its own deadline so cancellation is scoped to one
// statement at a time.
func withStatementTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
if timeout <= 0 {
return parent, func() {}
}
return context.WithTimeout(parent, timeout)
}

// reportCancellation distinguishes the three error cases that look the same
// from `executeOne`'s POV (a wrapped pgconn / network error): user cancelled
// via Ctrl+C, --timeout fired, or the statement just plain errored. Returns
// the human-readable message and whether the cause is invocation-scoped
// (cancel/timeout) rather than statement-scoped.
//
// Precedence: user cancel beats deadline. If both contexts fire near-
// simultaneously (race), we report "cancelled" because the user's intent
// dominates a coincidental timeout.
func reportCancellation(signalCtx, stmtCtx context.Context, err error, timeout time.Duration) (msg string, invocationScoped bool) {
switch {
case errors.Is(signalCtx.Err(), context.Canceled):
return "Query cancelled.", true
case timeout > 0 && errors.Is(stmtCtx.Err(), context.DeadlineExceeded):
return fmt.Sprintf("Query timed out after %s.", timeout), true
default:
return formatPgError(err), false
}
}

// newSinkInteractive returns the rowSink for the chosen output format. When
// interactive is true the text sink may launch the libs/tableview viewer for
// results larger than staticTableThreshold; when false it uses the static
// tabwriter table.
func newSinkInteractive(format sqlcli.Format, out, stderr io.Writer, interactive bool) rowSink {
switch format {
case sqlcli.OutputJSON:
return newJSONSink(out, stderr)
case sqlcli.OutputCSV:
return newCSVSink(out, stderr)
default:
if interactive {
return newInteractiveTextSink(out)
}
return newTextSink(out)
}
}
Expand All @@ -254,13 +319,6 @@ func renderPartial(out, stderr io.Writer, format sqlcli.Format, results []*unitR
}
}

// formatExecutionError produces the error returned to cobra when an input
// unit failed. The message includes the source label so the user knows
// which of N inputs blew up.
func formatExecutionError(source string, err error) error {
return fmt.Errorf("%s: %s", source, formatPgError(err))
}

// multiStatementHint is appended to errMultipleStatements so users see the
// recovery path inline.
const multiStatementHint = "\nThis command runs one statement per input. To run multiple statements:\n" +
Expand Down
35 changes: 32 additions & 3 deletions experimental/postgres/cmd/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,44 @@ import (
"strings"
"text/tabwriter"

"github.com/databricks/cli/libs/tableview"
"github.com/jackc/pgx/v5/pgconn"
)

// staticTableThreshold is the row count above which we hand off to
// libs/tableview's interactive viewer (when stdout is interactive). Smaller
// results stay in the static tabwriter path so they stream to a pipe
// unchanged. Matches the threshold aitools query uses.
const staticTableThreshold = 30

// textSink renders results as plain text: a tabwriter-aligned table for
// rows-producing statements, the command tag for command-only ones.
//
// Text output buffers all rows because tabwriter needs the widest cell in each
// column before it can align. Streaming output is provided by the JSON and CSV
// sinks; users with huge result sets should pick those.
//
// When interactive is true and the result has more than staticTableThreshold
// rows, End hands off to libs/tableview's scrollable viewer instead of
// emitting the static table. The interactive path requires a real TTY and a
// prompt-capable terminal; the caller decides.
type textSink struct {
out io.Writer
columns []string
rows [][]string
out io.Writer
interactive bool
columns []string
rows [][]string
}

func newTextSink(out io.Writer) *textSink {
return &textSink{out: out}
}

// newInteractiveTextSink returns a text sink that uses the interactive table
// viewer for results larger than staticTableThreshold.
func newInteractiveTextSink(out io.Writer) *textSink {
return &textSink{out: out, interactive: true}
}

func (s *textSink) Begin(fields []pgconn.FieldDescription) error {
s.columns = make([]string, len(fields))
for i, f := range fields {
Expand Down Expand Up @@ -61,6 +80,16 @@ func (s *textSink) End(commandTag string) error {
return err
}

if s.interactive && len(s.rows) > staticTableThreshold {
// Try the interactive viewer; on failure (TUI startup, terminal
// resize race, etc.) fall through to the static path so the user
// still sees the rows their query returned. Without this fallback
// a successful query would surface as "viewer failed" with no data.
if err := tableview.Run(s.out, s.columns, s.rows); err == nil {
return nil
}
}

tw := tabwriter.NewWriter(s.out, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, strings.Join(s.columns, "\t"))
fmt.Fprintln(tw, strings.Join(headerSeparator(s.columns), "\t"))
Expand Down
40 changes: 40 additions & 0 deletions experimental/postgres/cmd/signals.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package postgrescmd

import (
"context"
"os"
"os/signal"
"syscall"
)

// watchInterruptSignals installs handlers for SIGINT and SIGTERM that call
// cancel when the user hits Ctrl+C or the process gets a SIGTERM.
//
// Returns a stop-and-cancel function that uninstalls the handlers (signal.Stop
// prevents future OS deliveries) and cancels the parent context so the
// goroutine wakes promptly. The caller must defer it. The channel is
// 1-buffered and GC'd on return; no explicit drain is needed.
//
// On Windows, Go maps Ctrl+C to os.Interrupt via the console-control-handler,
// so the same code path covers Windows.
func watchInterruptSignals(ctx context.Context, cancel context.CancelFunc) func() {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)

done := make(chan struct{})
go func() {
select {
case <-sigCh:
cancel()
case <-ctx.Done():
}
close(done)
}()

return func() {
signal.Stop(sigCh)
// Wake the goroutine in case neither sigCh nor ctx.Done has fired.
cancel()
<-done
}
}
Loading