From 495849d64a2b68eb1426743c4a624790fca96129 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:30:21 +0200 Subject: [PATCH 1/3] Multi-input + multi-statement rejection + pg error formatting This is PR 3 of the experimental postgres query stack. Adds the rest of the input ergonomics promised in the plan and the error-formatting polish. Inputs: positional args become variadic, --file is repeatable, stdin is read when neither is present, and a positional ending in '.sql' that exists on disk is treated as a SQL file. Execution order is files-first then positionals (cobra/pflag does not preserve interleaved spelling, documented in --help). Each input unit must contain exactly one statement. checkSingleStatement walks the SQL with a hand-written conservative scanner that ignores ';' inside single-quoted strings, double-quoted identifiers, line comments, block comments, and dollar-quoted bodies. Multi-statement strings are rejected before connect with a hint pointing at the multi-input alternatives. Multi-input output: - text: each per-unit result rendered inline, separated by a blank line (mirrors psql's compact text shape). - json: top-level array of per-unit result objects with shape {"sql","kind","elapsed_ms",...}; rows-producing units carry a "rows":[...] array, command-only carry "command"+"rows_affected". Each per-unit object is buffered to completion before write; the outer array streams across units. The plan accepts this trade-off: huge SELECTs in multi-input invocations buffer. - csv: rejected pre-flight when N>1 (no sensible cross-schema shape). Single-input csv keeps streaming. Per-unit errors render as a {"kind":"error", ...} entry in the JSON shape so scripts can detect failure without checking exit code. Sequential execution stops on the first failing unit; the successful prefix is rendered. formatPgError renders *pgconn.PgError with SEVERITY, SQLSTATE, DETAIL, HINT inline. 
Non-PgError values pass through unchanged so connect-time errors keep their original wording. Single-input keeps the streaming sinks from PR 2; only multi-input goes through the buffered renderer. Session state (SET, temp tables) carries across input units because they share one connection. TUI for >30 rows is deferred to a follow-up. The current text path uses the static tabwriter table for both single- and multi-input. Co-authored-by: Isaac --- .../postgres/query/argument-errors/output.txt | 23 +- .../postgres/query/argument-errors/script | 14 ++ experimental/postgres/cmd/error.go | 42 ++++ experimental/postgres/cmd/error_test.go | 48 ++++ experimental/postgres/cmd/inputs.go | 102 +++++++++ experimental/postgres/cmd/inputs_test.go | 101 +++++++++ experimental/postgres/cmd/multistatement.go | 159 +++++++++++++ .../postgres/cmd/multistatement_test.go | 54 +++++ experimental/postgres/cmd/query.go | 147 +++++++++--- experimental/postgres/cmd/render_multi.go | 209 ++++++++++++++++++ .../postgres/cmd/render_multi_test.go | 89 ++++++++ experimental/postgres/cmd/result.go | 62 ++++++ 12 files changed, 1017 insertions(+), 33 deletions(-) create mode 100644 experimental/postgres/cmd/error.go create mode 100644 experimental/postgres/cmd/error_test.go create mode 100644 experimental/postgres/cmd/inputs.go create mode 100644 experimental/postgres/cmd/inputs_test.go create mode 100644 experimental/postgres/cmd/multistatement.go create mode 100644 experimental/postgres/cmd/multistatement_test.go create mode 100644 experimental/postgres/cmd/render_multi.go create mode 100644 experimental/postgres/cmd/render_multi_test.go create mode 100644 experimental/postgres/cmd/result.go diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt index 238e099299..3b6fe7910a 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt +++ 
b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt @@ -1,11 +1,11 @@ === No SQL argument should error: >>> musterr [CLI] experimental postgres query --target projects/foo -Error: accepts 1 arg(s), received 0 +Error: no SQL provided === Empty SQL should error: >>> musterr [CLI] experimental postgres query --target projects/foo -Error: no SQL provided +Error: argv[1] is empty === Neither targeting form should error: >>> musterr [CLI] experimental postgres query SELECT 1 @@ -42,3 +42,22 @@ Error: invalid resource path: missing project ID === Trailing components after endpoint should error: >>> musterr [CLI] experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra SELECT 1 Error: invalid resource path: trailing components after endpoint: projects/foo/branches/bar/endpoints/baz/extra + +=== Multi-statement string should error with hint: +>>> musterr [CLI] experimental postgres query --target projects/foo SELECT 1; SELECT 2 +Error: argv[1]: input contains multiple statements (a ';' separates two or more statements) +This command runs one statement per input. 
To run multiple statements: + - Pass each as a separate positional: query "SELECT 1" "SELECT 2" + - Pass each in its own --file: query --file q1.sql --file q2.sql + +=== CSV with multiple inputs should reject pre-flight: +>>> musterr [CLI] experimental postgres query --target projects/foo --output csv SELECT 1 SELECT 2 +Error: --output csv requires a single input unit; got 2 (use --output json for multi-input invocations) + +=== Empty file should error: +>>> musterr [CLI] experimental postgres query --target projects/foo --file empty.sql +Error: --file "empty.sql" is empty + +=== Missing file should error: +>>> musterr [CLI] experimental postgres query --target projects/foo --file /tmp/does-not-exist.sql +Error: read --file "/tmp/does-not-exist.sql": open /tmp/does-not-exist.sql: no such file or directory diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script index ac6ac42746..a1401d3b8e 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/script +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/script @@ -30,3 +30,17 @@ trace musterr $CLI experimental postgres query --target projects/ "SELECT 1" title "Trailing components after endpoint should error:" trace musterr $CLI experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra "SELECT 1" + +title "Multi-statement string should error with hint:" +trace musterr $CLI experimental postgres query --target projects/foo "SELECT 1; SELECT 2" + +title "CSV with multiple inputs should reject pre-flight:" +trace musterr $CLI experimental postgres query --target projects/foo --output csv "SELECT 1" "SELECT 2" + +title "Empty file should error:" +echo "" > empty.sql +trace musterr $CLI experimental postgres query --target projects/foo --file empty.sql +rm -f empty.sql + +title "Missing file should error:" +trace musterr $CLI experimental postgres query --target projects/foo 
--file /tmp/does-not-exist.sql diff --git a/experimental/postgres/cmd/error.go b/experimental/postgres/cmd/error.go new file mode 100644 index 0000000000..02278a6c58 --- /dev/null +++ b/experimental/postgres/cmd/error.go @@ -0,0 +1,42 @@ +package postgrescmd + +import ( + "errors" + "fmt" + "strings" + + "github.com/jackc/pgx/v5/pgconn" +) + +// formatPgError renders an error in a friendlier form when it's a Postgres +// server-side error. *pgconn.PgError exposes Code, Severity, Message, Detail, +// Hint, and Position; the plain text form attaches what's set so users see +// SQLSTATE plus any hint upstream included. +// +// For non-PgError values, returns err.Error() unchanged so the caller can +// surface it directly. The richer LINE+caret rendering is out of scope for +// this PR; we stick with the plain shape for now. +func formatPgError(err error) string { + var pgErr *pgconn.PgError + if !errors.As(err, &pgErr) { + return err.Error() + } + + var sb strings.Builder + if pgErr.Severity != "" { + fmt.Fprintf(&sb, "%s: ", pgErr.Severity) + } else { + sb.WriteString("ERROR: ") + } + sb.WriteString(pgErr.Message) + if pgErr.Code != "" { + fmt.Fprintf(&sb, " (SQLSTATE %s)", pgErr.Code) + } + if pgErr.Detail != "" { + fmt.Fprintf(&sb, "\nDETAIL: %s", pgErr.Detail) + } + if pgErr.Hint != "" { + fmt.Fprintf(&sb, "\nHINT: %s", pgErr.Hint) + } + return sb.String() +} diff --git a/experimental/postgres/cmd/error_test.go b/experimental/postgres/cmd/error_test.go new file mode 100644 index 0000000000..f4d709468d --- /dev/null +++ b/experimental/postgres/cmd/error_test.go @@ -0,0 +1,48 @@ +package postgrescmd + +import ( + "errors" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" +) + +func TestFormatPgError_NonPgError(t *testing.T) { + err := errors.New("plain error") + assert.Equal(t, "plain error", formatPgError(err)) +} + +func TestFormatPgError_BasicPgError(t *testing.T) { + err := &pgconn.PgError{ + Severity: "ERROR", + Code: 
"42601", + Message: `syntax error at or near "FRO"`, + } + assert.Equal(t, + `ERROR: syntax error at or near "FRO" (SQLSTATE 42601)`, + formatPgError(err), + ) +} + +func TestFormatPgError_WithDetailAndHint(t *testing.T) { + err := &pgconn.PgError{ + Severity: "ERROR", + Code: "42601", + Message: `syntax error at or near "FRO"`, + Hint: `Did you mean "FROM"?`, + Detail: "more context", + } + got := formatPgError(err) + assert.Contains(t, got, "ERROR:") + assert.Contains(t, got, "(SQLSTATE 42601)") + assert.Contains(t, got, "DETAIL: more context") + assert.Contains(t, got, `HINT: Did you mean "FROM"?`) +} + +func TestFormatPgError_WrappedPgError(t *testing.T) { + pg := &pgconn.PgError{Code: "42501", Message: "permission denied"} + wrapped := errors.New("query failed: " + pg.Error()) + // Plain error doesn't unwrap; falls through to err.Error. + assert.Contains(t, formatPgError(wrapped), "permission denied") +} diff --git a/experimental/postgres/cmd/inputs.go b/experimental/postgres/cmd/inputs.go new file mode 100644 index 0000000000..3cc64d45ad --- /dev/null +++ b/experimental/postgres/cmd/inputs.go @@ -0,0 +1,102 @@ +package postgrescmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/databricks/cli/libs/cmdio" +) + +// sqlFileExtension is the file suffix that triggers the .sql autodetect on a +// positional argument: if `databricks ... query foo.sql` exists on disk, we +// read it as a SQL file; otherwise it's treated as literal SQL. +const sqlFileExtension = ".sql" + +// inputUnit is one SQL statement to execute, paired with metadata so the +// renderer can identify its origin in multi-input output shapes. +type inputUnit struct { + // SQL is the trimmed statement text. Always non-empty by the time the + // scanner has rejected multi-statement strings and empty inputs. + SQL string + // Source is a human-readable label for this input ("--file query.sql", + // "stdin", or "argv[1]"). 
Used by the multi-input JSON renderer's "sql" + // field hint and by the rich error formatter. + Source string +} + +// collectInputs assembles the ordered list of input units from positional +// arguments, --file flags, and stdin. +// +// Execution order is files-first then positionals (plan section "Statement +// execution"). Cobra/pflag does not preserve the user's interleaved CLI +// spelling: it collects all --file flags into one slice and all positionals +// into another, so we cannot honour `--file q1.sql "SELECT 1" --file q2.sql` +// as written. This is documented in --help. +// +// Stdin is read only when neither positional nor --file is provided. +func collectInputs(ctx context.Context, in io.Reader, args, files []string) ([]inputUnit, error) { + var units []inputUnit + + for _, path := range files { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read --file %q: %w", path, err) + } + sql := strings.TrimSpace(string(data)) + if sql == "" { + return nil, fmt.Errorf("--file %q is empty", path) + } + units = append(units, inputUnit{SQL: sql, Source: "--file " + path}) + } + + for i, arg := range args { + // .sql autodetect: if the positional ends in .sql AND the file + // exists, read it as a SQL file. Other read errors (permission + // denied) surface directly. If the file does not exist, fall through + // and treat the positional as literal SQL — useful when the user + // passes a string that happens to end with ".sql". 
+ if strings.HasSuffix(arg, sqlFileExtension) { + data, err := os.ReadFile(arg) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("read positional %q: %w", arg, err) + } + if err == nil { + sql := strings.TrimSpace(string(data)) + if sql == "" { + return nil, fmt.Errorf("positional %q is empty", arg) + } + units = append(units, inputUnit{SQL: sql, Source: arg}) + continue + } + } + sql := strings.TrimSpace(arg) + if sql == "" { + return nil, fmt.Errorf("argv[%d] is empty", i+1) + } + units = append(units, inputUnit{SQL: sql, Source: fmt.Sprintf("argv[%d]", i+1)}) + } + + if len(units) == 0 { + // No positionals, no --file: read from stdin if it's not a prompt- + // supporting TTY. The aitools query helper applies the same rule. + _, isOsFile := in.(*os.File) + if isOsFile && cmdio.IsPromptSupported(ctx) { + return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") + } + data, err := io.ReadAll(in) + if err != nil { + return nil, fmt.Errorf("read stdin: %w", err) + } + sql := strings.TrimSpace(string(data)) + if sql == "" { + return nil, errors.New("no SQL provided") + } + units = append(units, inputUnit{SQL: sql, Source: "stdin"}) + } + + return units, nil +} diff --git a/experimental/postgres/cmd/inputs_test.go b/experimental/postgres/cmd/inputs_test.go new file mode 100644 index 0000000000..97d3d2abc7 --- /dev/null +++ b/experimental/postgres/cmd/inputs_test.go @@ -0,0 +1,101 @@ +package postgrescmd + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func writeTemp(t *testing.T, name, contents string) string { + t.Helper() + dir := t.TempDir() + p := filepath.Join(dir, name) + require.NoError(t, os.WriteFile(p, []byte(contents), 0o644)) + return p +} + +func TestCollectInputs_PositionalOnly(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{"SELECT 1"}, nil) + 
require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "SELECT 1", units[0].SQL) + assert.Equal(t, "argv[1]", units[0].Source) +} + +func TestCollectInputs_MultiplePositionals(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{"SELECT 1", "SELECT 2"}, nil) + require.NoError(t, err) + require.Len(t, units, 2) + assert.Equal(t, "SELECT 1", units[0].SQL) + assert.Equal(t, "SELECT 2", units[1].SQL) +} + +func TestCollectInputs_FileOnly(t *testing.T) { + p := writeTemp(t, "q.sql", "SELECT * FROM t") + units, err := collectInputs(t.Context(), strings.NewReader(""), nil, []string{p}) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "SELECT * FROM t", units[0].SQL) + assert.Contains(t, units[0].Source, "--file") +} + +func TestCollectInputs_FilesFirstThenPositionals(t *testing.T) { + p1 := writeTemp(t, "a.sql", "SELECT 1") + p2 := writeTemp(t, "b.sql", "SELECT 2") + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{"SELECT 3"}, []string{p1, p2}) + require.NoError(t, err) + require.Len(t, units, 3) + assert.Equal(t, "SELECT 1", units[0].SQL) + assert.Equal(t, "SELECT 2", units[1].SQL) + assert.Equal(t, "SELECT 3", units[2].SQL) +} + +func TestCollectInputs_DotSQLAutoDetect(t *testing.T) { + p := writeTemp(t, "data.sql", "SELECT 42") + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{p}, nil) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "SELECT 42", units[0].SQL) +} + +func TestCollectInputs_DotSQLNotExistingFallsThroughToLiteral(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{"/nonexistent/path.sql"}, nil) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "/nonexistent/path.sql", units[0].SQL) +} + +func TestCollectInputs_StdinOnly(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader("SELECT 1\n"), nil, nil) + require.NoError(t, err) + 
require.Len(t, units, 1) + assert.Equal(t, "SELECT 1", units[0].SQL) + assert.Equal(t, "stdin", units[0].Source) +} + +func TestCollectInputs_StdinIgnoredWhenPositionalsPresent(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader("FROM STDIN"), []string{"SELECT 1"}, nil) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "SELECT 1", units[0].SQL) +} + +func TestCollectInputs_EmptyStdinErrors(t *testing.T) { + _, err := collectInputs(t.Context(), strings.NewReader(""), nil, nil) + assert.ErrorContains(t, err, "no SQL provided") +} + +func TestCollectInputs_EmptyFileErrors(t *testing.T) { + p := writeTemp(t, "empty.sql", "") + _, err := collectInputs(t.Context(), strings.NewReader(""), nil, []string{p}) + assert.ErrorContains(t, err, "is empty") +} + +func TestCollectInputs_EmptyPositional(t *testing.T) { + _, err := collectInputs(t.Context(), strings.NewReader(""), []string{" "}, nil) + assert.ErrorContains(t, err, "is empty") +} diff --git a/experimental/postgres/cmd/multistatement.go b/experimental/postgres/cmd/multistatement.go new file mode 100644 index 0000000000..4c4f976e8e --- /dev/null +++ b/experimental/postgres/cmd/multistatement.go @@ -0,0 +1,159 @@ +package postgrescmd + +import ( + "errors" + "strings" +) + +// errMultipleStatements is the typed error returned by checkSingleStatement +// when the input contains more than one ';'-separated statement. The runQuery +// path catches this with errors.Is to attach the multi-input workaround +// pointer in the user-visible message. +var errMultipleStatements = errors.New("input contains multiple statements (a ';' separates two or more statements)") + +// checkSingleStatement walks sql and returns errMultipleStatements if a +// statement-terminating ';' is found anywhere except trailing whitespace. 
+// +// The scanner ignores ';' inside: +// - single-quoted strings ('a;b', SQL standard doubled-quote escape) +// - double-quoted identifiers ("col;name") +// - line comments (-- ... \n) +// - block comments (/* ... */, non-nesting) +// - dollar-quoted bodies ($tag$ ... $tag$, optional tag) +// +// Over-rejection on weird syntactic edge cases is acceptable: users get a +// clear error and can split into multiple input units. v2 may swap this for +// a real Postgres tokenizer. +func checkSingleStatement(sql string) error { + s := sql + // Trim trailing whitespace once so a single trailing ';' is allowed. + end := len(strings.TrimRight(s, " \t\r\n")) + + i := 0 + for i < end { + c := s[i] + + switch c { + case ';': + // A ';' that's not at end-of-trimmed-input is a separator. + if i < end-1 { + return errMultipleStatements + } + // Trailing ';' is fine. + i++ + + case '\'': + // Single-quoted string. SQL standard escape is '' (doubled). + i = scanQuoted(s, i, end, '\'') + + case '"': + // Double-quoted identifier. Same '"' doubling escape rule. + i = scanQuoted(s, i, end, '"') + + case '-': + // Line comment "--" runs to next newline. + if i+1 < end && s[i+1] == '-' { + i = scanLineComment(s, i, end) + } else { + i++ + } + + case '/': + // Block comment "/* ... */". + if i+1 < end && s[i+1] == '*' { + i = scanBlockComment(s, i, end) + } else { + i++ + } + + case '$': + // Dollar-quoted body: $tag$ ... $tag$ (tag may be empty). + tag, end2 := readDollarTag(s, i, end) + if tag != "" || end2 > i { + i = scanDollarBody(s, end2, end, tag) + } else { + i++ + } + + default: + i++ + } + } + + return nil +} + +// scanQuoted advances past a quoted string or identifier opened at s[start] +// with the given quote character. SQL standard doubles the quote to escape +// (e.g. doubling the quote inside the string). 
Returns the index of the byte AFTER the closing quote, or
+// end if the string is unterminated (over-permissive: an unterminated string
+// at EOF means there's no ';' inside it anyway).
+func scanQuoted(s string, start, end int, quote byte) int {
+	i := start + 1
+	for i < end {
+		if s[i] == quote {
+			if i+1 < end && s[i+1] == quote {
+				i += 2 // doubled-quote escape
+				continue
+			}
+			return i + 1
+		}
+		i++
+	}
+	return end
+}
+
+func scanLineComment(s string, start, end int) int {
+	i := start + 2
+	for i < end && s[i] != '\n' {
+		i++
+	}
+	return i
+}
+
+func scanBlockComment(s string, start, end int) int {
+	i := start + 2
+	for i+1 < end {
+		if s[i] == '*' && s[i+1] == '/' {
+			return i + 2
+		}
+		i++
+	}
+	return end
+}
+
+// readDollarTag inspects s[start] (which must be '$') and returns the tag
+// between the two dollar signs and the index of the byte just past the
+// second '$' of the opening $tag$ marker. If the construct doesn't look
+// like a valid dollar-quote opener, returns ("", start) so the caller can
+// fall through.
+//
+// Tag rule: starts after '$' and runs to the next '$'. Real Postgres limits
+// tags to letters, digits, and underscores; this scanner accepts any byte
+// except whitespace and ';' (over-permissive). Empty tag is valid: $$ is a
+// marker, $$body$$ is the body.
+func readDollarTag(s string, start, end int) (string, int) {
+	i := start + 1
+	for i < end {
+		if s[i] == '$' {
+			tag := s[start+1 : i]
+			return tag, i + 1
+		}
+		// Stop at characters that can't be in a tag.
+		if s[i] == ' ' || s[i] == '\t' || s[i] == '\n' || s[i] == ';' {
+			return "", start
+		}
+		i++
+	}
+	return "", start
+}
+
+// scanDollarBody advances past a $tag$...$tag$ body starting at start (the
+// byte right after the opening tag's closing '$'). Returns the index of the
+// byte AFTER the closing tag, or end if unterminated.
+func scanDollarBody(s string, start, end int, tag string) int { + close := "$" + tag + "$" + idx := strings.Index(s[start:end], close) + if idx < 0 { + return end + } + return start + idx + len(close) +} diff --git a/experimental/postgres/cmd/multistatement_test.go b/experimental/postgres/cmd/multistatement_test.go new file mode 100644 index 0000000000..bb50bf5e8e --- /dev/null +++ b/experimental/postgres/cmd/multistatement_test.go @@ -0,0 +1,54 @@ +package postgrescmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCheckSingleStatement(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {name: "single statement", input: "SELECT 1", wantErr: false}, + {name: "trailing semicolon allowed", input: "SELECT 1;", wantErr: false}, + {name: "trailing semicolon plus whitespace", input: "SELECT 1;\n ", wantErr: false}, + {name: "two statements rejected", input: "SELECT 1; SELECT 2", wantErr: true}, + {name: "two statements with trailing semi", input: "SELECT 1; SELECT 2;", wantErr: true}, + + {name: "semicolon in single-quoted string", input: "SELECT 'a;b'", wantErr: false}, + {name: "semicolon in double-quoted ident", input: `SELECT "col;name" FROM t`, wantErr: false}, + {name: "doubled quote escape", input: "SELECT 'it''s;ok'", wantErr: false}, + {name: "doubled identifier quote", input: `SELECT "x""y;z" FROM t`, wantErr: false}, + + {name: "semicolon in line comment", input: "SELECT 1 -- x;y\n", wantErr: false}, + {name: "semicolon in block comment", input: "SELECT 1 /* x;y */", wantErr: false}, + {name: "block comment unterminated", input: "SELECT 1 /* unterminated", wantErr: false}, + + {name: "semicolon in dollar body untagged", input: "SELECT $$a;b$$", wantErr: false}, + {name: "semicolon in dollar body tagged", input: "SELECT $tag$a;b$tag$", wantErr: false}, + {name: "create function with body", input: "CREATE FUNCTION f() RETURNS int AS $$ BEGIN; END $$ LANGUAGE plpgsql", wantErr: false}, + + 
{name: "semi inside string then real semi", input: "SELECT 'a;b'; SELECT 2", wantErr: true}, + {name: "semi inside line comment then real semi", input: "SELECT 1 -- ; \n; SELECT 2", wantErr: true}, + {name: "semi inside dollar then real semi", input: "SELECT $$a;b$$; SELECT 2", wantErr: true}, + + {name: "leading whitespace", input: " ;", wantErr: false}, + {name: "empty input", input: "", wantErr: false}, + {name: "only whitespace", input: " \n\t ", wantErr: false}, + {name: "only semicolon", input: ";", wantErr: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := checkSingleStatement(tc.input) + if tc.wantErr { + assert.ErrorIs(t, err, errMultipleStatements) + return + } + assert.NoError(t, err) + }) + } +} diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 2b4f12694f..bc339c29d4 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -2,10 +2,8 @@ package postgrescmd import ( "context" - "errors" "fmt" "io" - "strings" "time" "github.com/databricks/cli/cmd/root" @@ -26,6 +24,7 @@ type queryFlags struct { database string connectTimeout time.Duration maxRetries int + files []string // outputFormat is the raw flag value. resolveOutputFormat turns it into // the effective format (which may differ when stdout is piped). @@ -37,9 +36,9 @@ func newQueryCmd() *cobra.Command { var f queryFlags cmd := &cobra.Command{ - Use: "query [SQL]", - Short: "Run a SQL statement against a Lakebase Postgres endpoint", - Long: `Execute a single SQL statement against a Lakebase Postgres endpoint. + Use: "query [SQL | file.sql]...", + Short: "Run SQL statements against a Lakebase Postgres endpoint", + Long: `Execute one or more SQL statements against a Lakebase Postgres endpoint. 
Targeting (exactly one form required): --target STRING Provisioned instance name OR autoscaling resource path @@ -48,37 +47,43 @@ Targeting (exactly one form required): --branch ID Autoscaling branch ID (default: auto-select if exactly one) --endpoint ID Autoscaling endpoint ID +Inputs (positionals and --file may be combined; execution order is files-first +then positionals; stdin is used only when neither is present): + -f, --file PATH SQL file path (repeatable). Each file must contain + exactly one statement. + positional SQL string OR path ending in '.sql' that exists on disk. + Output: --output text Aligned table for rows-producing statements (default). Falls back to JSON when stdout is not a terminal so scripts piping the output get machine-readable results. - --output json Top-level array of row objects, streamed for - rows-producing statements. Command-only statements - emit a single {"command": "...", "rows_affected": N} - object. Numbers, booleans, NULL, jsonb, timestamps - render with their JSON-native types. + --output json For a single input: top-level array of row objects, + streamed. For multiple inputs: top-level array of + per-unit result objects ({"sql","kind","elapsed_ms",...}), + with each object buffered to completion. --output csv Header row + one CSV row per result row, streamed. - Command-only statements write the command tag to - stderr. + Single-input only; multi-input + csv is rejected + pre-flight. Use --output json for multi-input. DATABRICKS_OUTPUT_FORMAT is honoured when --output is not explicitly set. -This is an experimental command. The flag set, output shape, and supported -target kinds will expand in subsequent releases. - Limitations (this release): - - Single SQL statement per invocation (multi-statement support comes later). + - Single statement per input unit. Multi-statement strings (e.g. + "SELECT 1; SELECT 2") are rejected; pass each as a separate positional + or --file. - No interactive REPL. 
'databricks psql' continues to own that surface. - - Multi-statement strings (e.g. "SELECT 1; SELECT 2") are not supported. + - Inputs run sequentially on one connection; session state (SET, temp + tables, prepared statement names) carries across them. - The OAuth token is generated once per invocation and is valid for 1h. Queries longer than that fail with an auth error. + - --output csv is rejected when more than one input unit is present; + use --output json or split into separate invocations. `, - Args: cobra.ExactArgs(1), PreRunE: root.MustWorkspaceClient, RunE: func(cmd *cobra.Command, args []string) error { f.outputFormatSet = cmd.Flag("output").Changed - return runQuery(cmd.Context(), cmd, args[0], f) + return runQuery(cmd.Context(), cmd, args, f) }, } @@ -89,6 +94,7 @@ Limitations (this release): cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") + cmd.Flags().StringArrayVarP(&f.files, "file", "f", nil, "SQL file path (repeatable)") cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(outputText), "Output format: text, json, or csv") cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { out := make([]string, len(allOutputFormats)) @@ -108,11 +114,7 @@ Limitations (this release): // runQuery is the production entry point. It is split out from RunE so unit // tests can call it directly with a stubbed connectFunc once we add seam-based // tests in a later PR. 
-func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) error { - sql = strings.TrimSpace(sql) - if sql == "" { - return errors.New("no SQL provided") - } +func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFlags) error { if f.maxRetries < 1 { return fmt.Errorf("--max-retries must be at least 1; got %d", f.maxRetries) } @@ -120,17 +122,30 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } - // SupportsColor is the public TTY-ish signal libs/cmdio exposes today; it - // also folds in NO_COLOR / TERM=dumb, which strictly speaking are colour - // preferences rather than TTY signals. Users who hit that edge case can - // pass --output text explicitly; that path is honoured (see - // resolveOutputFormat). Mirrors the aitools query command. + units, err := collectInputs(ctx, cmd.InOrStdin(), args, f.files) + if err != nil { + return err + } + for _, u := range units { + if err := checkSingleStatement(u.SQL); err != nil { + return fmt.Errorf("%s: %w%s", u.Source, err, multiStatementHint()) + } + } + stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) format, err := resolveOutputFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) if err != nil { return err } + // CSV multi-input is rejected pre-flight: there is no sensible shape for + // a CSV that has to merge schemas across statements. The error names the + // flag pair and tells the user how to recover, per the repo rule about + // rejecting incompatible inputs early. 
+	if format == outputCSV && len(units) > 1 {
+		return fmt.Errorf("--output csv requires a single input unit; got %d (use --output json for multi-input invocations)", len(units))
+	}
+
 	resolved, err := resolveTarget(ctx, f.targetingFlags)
 	if err != nil {
 		return err
 	}
@@ -162,8 +177,43 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags)
 	}
 	defer conn.Close(context.WithoutCancel(ctx))
 
-	sink := newSink(format, cmd.OutOrStdout(), cmd.ErrOrStderr())
-	return executeOne(ctx, conn, sql, sink)
+	out := cmd.OutOrStdout()
+	stderr := cmd.ErrOrStderr()
+
+	if len(units) == 1 {
+		// Single-input path: stream directly through the per-format sink.
+		// Avoids buffering rows for large exports and matches the v1 single-
+		// input behaviour PR 2 shipped.
+		sink := newSink(format, out, stderr)
+		return executeOne(ctx, conn, units[0].SQL, sink)
+	}
+
+	// Multi-input path: per-unit buffering. The plan accepts this trade-off
+	// (multi-input invocations with huge SELECTs should use single-input
+	// invocations with --output csv for streaming). Session state (SET,
+	// temp tables) carries across units because we hold the same connection.
+	results := make([]*unitResult, 0, len(units))
+	for _, u := range units {
+		r, err := runUnitBuffered(ctx, conn, u)
+		if err != nil {
+			// Render the successful prefix, then surface the error with
+			// rich pgError formatting if applicable.
+			if rerr := renderPartial(out, stderr, format, results, u, err); rerr != nil {
+				// Best-effort partial render failed; surface the original
+				// error to the user, the renderer error to debug logs.
+				fmt.Fprintln(stderr, "warning: failed to render partial result:", rerr)
+			}
+			return formatExecutionError(u.Source, err)
+		}
+		results = append(results, r)
+	}
+
+	switch format {
+	case outputJSON:
+		return renderJSONMulti(out, stderr, results, -1, "")
+	default:
+		return renderTextMulti(out, results)
+	}
+}
 
 // newSink returns the rowSink for the chosen output format.
Kept separate @@ -178,3 +228,38 @@ func newSink(format outputFormat, out, stderr io.Writer) rowSink { return newTextSink(out) } } + +// renderPartial emits the rendered output for the prefix of units that ran +// successfully before a unit errored. For multi-input json this also writes +// the error envelope as the last array element. +func renderPartial(out, stderr io.Writer, format outputFormat, results []*unitResult, errored inputUnit, err error) error { + switch format { + case outputJSON: + return renderJSONMulti(out, stderr, results, len(results), formatExecutionErrorMessage(errored.Source, err)) + default: + // Text: render whatever ran cleanly. The error message goes through + // cobra's default error path on stderr after we return. + return renderTextMulti(out, results) + } +} + +// formatExecutionError produces the error returned to cobra when an input +// unit failed. The message includes the source label so the user knows +// which of N inputs blew up. +func formatExecutionError(source string, err error) error { + return fmt.Errorf("%s: %s", source, formatPgError(err)) +} + +// formatExecutionErrorMessage is the string form of formatExecutionError, +// suitable for embedding in JSON envelopes. +func formatExecutionErrorMessage(source string, err error) string { + return fmt.Sprintf("%s: %s", source, formatPgError(err)) +} + +// multiStatementHint is the workaround pointer appended to the +// errMultipleStatements error so users see the recovery path inline. +func multiStatementHint() string { + return "\nThis command runs one statement per input. 
To run multiple statements:\n" + + ` - Pass each as a separate positional: query "SELECT 1" "SELECT 2"` + "\n" + + ` - Pass each in its own --file: query --file q1.sql --file q2.sql` +} diff --git a/experimental/postgres/cmd/render_multi.go b/experimental/postgres/cmd/render_multi.go new file mode 100644 index 0000000000..6ffc1e8d58 --- /dev/null +++ b/experimental/postgres/cmd/render_multi.go @@ -0,0 +1,209 @@ +package postgrescmd + +import ( + "bytes" + "fmt" + "io" + "strings" +) + +// renderTextMulti renders a sequence of unit results as plain text. Each +// per-unit block follows the single-input layout (table for rows-producing, +// command tag for command-only); successive blocks are separated by a blank +// line, mirroring `psql -c "...; ..."` shape. +// +// errIndex/errResult identifies the unit that errored (-1 if none); we still +// render any successful prefix. The error itself is surfaced by the caller +// via cobra's default error rendering. +func renderTextMulti(out io.Writer, results []*unitResult) error { + for i, r := range results { + if i > 0 { + if _, err := io.WriteString(out, "\n"); err != nil { + return err + } + } + if err := renderTextResult(out, r); err != nil { + return err + } + } + return nil +} + +// renderTextResult renders a single buffered unitResult in the same shape as +// textSink would for a streamed result. +func renderTextResult(out io.Writer, r *unitResult) error { + if !r.IsRowsProducing() { + _, err := fmt.Fprintln(out, r.CommandTag) + return err + } + + // Reuse textSink for the table layout so single-input and multi-input + // share the same alignment and footer logic. + sink := newTextSink(out) + if err := sink.Begin(r.Fields); err != nil { + return err + } + for _, row := range r.Rows { + if err := sink.Row(row); err != nil { + return err + } + } + return sink.End(r.CommandTag) +} + +// renderJSONMulti emits the wrapped multi-input JSON shape: a top-level +// array of result objects, one per input unit. 
Per-unit objects are buffered +// to completion before write; the outer array uses separator-before-element +// streaming. CSV multi-input is rejected pre-flight, so this function is +// only used for json. +// +// Per-unit shape: +// +// {"sql": "...", "kind": "rows", "elapsed_ms": N, "rows": [...]} +// {"sql": "...", "kind": "command", "elapsed_ms": N, "command": "...", "rows_affected": N} +// {"sql": "...", "kind": "error", "elapsed_ms": N, "error": {...}} +// +// kind discriminates which fields are present so consumers don't have to +// branch on key presence. +func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errIndex int, errMessage string) error { + if _, err := io.WriteString(out, "[\n"); err != nil { + return err + } + + for i, r := range results { + if i > 0 { + if _, err := io.WriteString(out, ",\n"); err != nil { + return err + } + } + var unitBuf bytes.Buffer + if err := renderJSONUnit(&unitBuf, stderr, r); err != nil { + return err + } + if _, err := out.Write(unitBuf.Bytes()); err != nil { + return err + } + } + + if errIndex >= 0 { + // The errored unit follows the last successful unit; write a comma + // separator and the error envelope for it. + if len(results) > 0 { + if _, err := io.WriteString(out, ",\n"); err != nil { + return err + } + } + errSQL := "" + errSource := "" + // errIndex points to the input *unit* index; since we render + // successful units in order, the errored unit's SQL came from the + // caller's units slice. The caller embeds it in errMessage so we + // don't need separate plumbing here. + obj := jsonErrorObject(errSource, errSQL, errMessage) + if _, err := out.Write(obj); err != nil { + return err + } + } + + _, err := io.WriteString(out, "\n]\n") + return err +} + +// renderJSONUnit writes one buffered result object to buf, using the +// existing single-input json rendering for the rows array so the value +// mapping stays consistent across single- and multi-input shapes. 
+func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { + if !r.IsRowsProducing() { + // Command-only unit. + if _, err := fmt.Fprintf(buf, `{"sql":`); err != nil { + return err + } + sqlJSON, err := marshalJSON(r.SQL) + if err != nil { + return err + } + buf.Write(sqlJSON) + fmt.Fprintf(buf, `,"kind":"command","elapsed_ms":%d`, r.Elapsed.Milliseconds()) + fmt.Fprintf(buf, `,"command":"%s"`, jsonEscapeShort(commandTagVerb(r.CommandTag))) + if rows, ok := commandTagRowCount(r.CommandTag); ok { + fmt.Fprintf(buf, `,"rows_affected":%d`, rows) + } + buf.WriteString(`}`) + return nil + } + + // Rows-producing unit. We reuse jsonSink for the rows array body so + // the per-row encoding (column order, type mapping) stays in one place. + if _, err := fmt.Fprintf(buf, `{"sql":`); err != nil { + return err + } + sqlJSON, err := marshalJSON(r.SQL) + if err != nil { + return err + } + buf.Write(sqlJSON) + fmt.Fprintf(buf, `,"kind":"rows","elapsed_ms":%d,"rows":`, r.Elapsed.Milliseconds()) + + rowsBuf := &bytes.Buffer{} + sink := newJSONSink(rowsBuf, stderr) + if err := sink.Begin(r.Fields); err != nil { + return err + } + for _, row := range r.Rows { + if err := sink.Row(row); err != nil { + return err + } + } + // Use a no-op tag for End so jsonSink's success path emits the closing + // bracket. The trailing newline gets trimmed below. + if err := sink.End(""); err != nil { + return err + } + rowsTrimmed := bytes.TrimRight(rowsBuf.Bytes(), "\n") + buf.Write(rowsTrimmed) + buf.WriteString(`}`) + return nil +} + +// jsonErrorObject builds the per-unit error envelope used in the multi-input +// JSON shape. message is the formatted error message (already includes +// SQLSTATE / hint / detail when applicable). 
+func jsonErrorObject(source, sql, message string) []byte { + var buf bytes.Buffer + buf.WriteString(`{"source":`) + if b, err := marshalJSON(source); err == nil { + buf.Write(b) + } else { + buf.WriteString(`""`) + } + buf.WriteString(`,"sql":`) + if b, err := marshalJSON(sql); err == nil { + buf.Write(b) + } else { + buf.WriteString(`""`) + } + buf.WriteString(`,"kind":"error","error":{"message":`) + if b, err := marshalJSON(message); err == nil { + buf.Write(b) + } else { + buf.WriteString(`""`) + } + buf.WriteString(`}}`) + return buf.Bytes() +} + +// jsonEscapeShort is a fast path for short ASCII strings (command tag verbs) +// that need backslash escapes for ", \, and control bytes only. Falls back +// to a string-escaped value if the input contains anything unusual. +func jsonEscapeShort(s string) string { + if !strings.ContainsAny(s, "\"\\\n\r\t") { + return s + } + out, err := marshalJSON(s) + if err != nil { + return s + } + // marshalJSON returns the value with surrounding quotes; strip them so + // the caller can wrap with its own quoting. 
+ return string(bytes.Trim(out, `"`)) +} diff --git a/experimental/postgres/cmd/render_multi_test.go b/experimental/postgres/cmd/render_multi_test.go new file mode 100644 index 0000000000..dba5174a43 --- /dev/null +++ b/experimental/postgres/cmd/render_multi_test.go @@ -0,0 +1,89 @@ +package postgrescmd + +import ( + "bytes" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRenderTextMulti_TwoResults(t *testing.T) { + r1 := &unitResult{ + Source: "argv[1]", + SQL: "INSERT INTO t VALUES (1)", + CommandTag: "INSERT 0 1", + Elapsed: 5 * time.Millisecond, + } + r2 := &unitResult{ + Source: "argv[2]", + SQL: "SELECT id FROM t", + Fields: fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}), + Rows: [][]any{{int64(1)}}, + CommandTag: "SELECT 1", + Elapsed: 3 * time.Millisecond, + } + + var buf bytes.Buffer + require.NoError(t, renderTextMulti(&buf, []*unitResult{r1, r2})) + out := buf.String() + assert.Contains(t, out, "INSERT 0 1\n") + assert.Contains(t, out, "id") + assert.Contains(t, out, "(1 row)") + // Blank-line separator between blocks. 
+ assert.Contains(t, out, "INSERT 0 1\n\n") +} + +func TestRenderJSONMulti_TwoResults(t *testing.T) { + r1 := &unitResult{ + Source: "argv[1]", + SQL: "INSERT INTO t VALUES (1)", + CommandTag: "INSERT 0 1", + Elapsed: 5 * time.Millisecond, + } + r2 := &unitResult{ + Source: "argv[2]", + SQL: "SELECT id FROM t", + Fields: fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}), + Rows: [][]any{{int64(1)}, {int64(2)}}, + CommandTag: "SELECT 2", + Elapsed: 3 * time.Millisecond, + } + + var stdout, stderr bytes.Buffer + require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1, r2}, -1, "")) + + out := stdout.String() + assert.Contains(t, out, `"sql":"INSERT INTO t VALUES (1)"`) + assert.Contains(t, out, `"kind":"command"`) + assert.Contains(t, out, `"command":"INSERT"`) + assert.Contains(t, out, `"rows_affected":1`) + assert.Contains(t, out, `"sql":"SELECT id FROM t"`) + assert.Contains(t, out, `"kind":"rows"`) + assert.Contains(t, out, `"rows":`) + // Outer array framing. + assert.Greater(t, len(out), 4) + assert.Equal(t, byte('['), out[0]) + assert.Equal(t, byte('\n'), out[len(out)-1]) +} + +func TestRenderJSONMulti_WithErrorAtEnd(t *testing.T) { + r1 := &unitResult{ + Source: "argv[1]", + SQL: "SELECT 1", + Fields: fieldsWithOIDs([]string{"?column?"}, []uint32{pgtype.Int8OID}), + Rows: [][]any{{int64(1)}}, + CommandTag: "SELECT 1", + Elapsed: 1 * time.Millisecond, + } + + var stdout, stderr bytes.Buffer + require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1}, 1, "argv[2]: ERROR: syntax error (SQLSTATE 42601)")) + + out := stdout.String() + assert.Contains(t, out, `"kind":"rows"`) + assert.Contains(t, out, `"kind":"error"`) + assert.Contains(t, out, `"message":"argv[2]: ERROR: syntax error (SQLSTATE 42601)"`) +} diff --git a/experimental/postgres/cmd/result.go b/experimental/postgres/cmd/result.go new file mode 100644 index 0000000000..d9b449a484 --- /dev/null +++ b/experimental/postgres/cmd/result.go @@ -0,0 +1,62 @@ +package 
postgrescmd + +import ( + "context" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// unitResult is the buffered result of running one input unit. The +// multi-input renderers (text, json) need rows buffered before they can +// emit a per-unit block; for the single-input path we still stream +// directly through a rowSink and never produce a unitResult. +type unitResult struct { + Source string + SQL string + Fields []pgconn.FieldDescription + Rows [][]any + CommandTag string + Elapsed time.Duration +} + +// IsRowsProducing returns whether the unit returned a row description. +func (r *unitResult) IsRowsProducing() bool { + return len(r.Fields) > 0 +} + +// runUnitBuffered runs sql and collects every row into memory. Used by the +// multi-input output paths (text and json), where per-unit buffering is +// acceptable per the plan: a multi-input invocation that emits a huge +// SELECT will buffer that result before printing. Users with huge result +// sets per statement should use single-input invocations (which fully +// stream) or --output csv on a single input. 
+func runUnitBuffered(ctx context.Context, conn *pgx.Conn, unit inputUnit) (*unitResult, error) { + start := time.Now() + rows, err := conn.Query(ctx, unit.SQL, pgx.QueryExecModeExec) + if err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + defer rows.Close() + + r := &unitResult{ + Source: unit.Source, + SQL: unit.SQL, + Fields: rows.FieldDescriptions(), + } + for rows.Next() { + values, err := rows.Values() + if err != nil { + return nil, fmt.Errorf("decode row: %w", err) + } + r.Rows = append(r.Rows, values) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + r.CommandTag = rows.CommandTag().String() + r.Elapsed = time.Since(start) + return r, nil +} From 78357bdb760b664ba98592d209d8d0aec34efc81 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:41:57 +0200 Subject: [PATCH 2/3] Address PR 3 review feedback round 1 MUSTs: - Multi-input JSON error envelope: thread the failing *unitResult into renderJSONMulti so source/sql/elapsed_ms reflect the actual failing input instead of empty strings. - Canonical key order for every per-unit object: {"source", "sql", "kind", "elapsed_ms", payload} Success and error envelopes now share the same shape so consumers don't have to special-case kind=="error" for missing fields. SHOULDs: - Single-input path now goes through formatPgError, so DETAIL/HINT surface consistently across single- and multi-input. - runUnitBuffered reuses executeOne via a new bufferSink. The two query loops collapse to one; future error-handling changes auto- propagate. - Scanner: reject `$...` as a dollar-quote tag (PG docs forbid digit-leading tags). Pinned with a test for `SELECT $1, $2 FROM t` and `SELECT $1 FROM t; SELECT 2`. - Pin the E-string over-rejection behaviour with a test, so a future scanner improvement has to update the assertion. CONSIDERs: - Capture elapsed_ms on the error path too (was previously discarded). - Promote multiStatementHint to a const. 
- Drop jsonEscapeShort (was a fragile micro-opt for an always-ASCII domain); use marshalJSON for the command verb instead. - Add TestRenderJSONMulti_FirstUnitFails to pin the empty-success- prefix framing. Co-authored-by: Isaac --- experimental/postgres/cmd/multistatement.go | 13 +- .../postgres/cmd/multistatement_test.go | 12 ++ experimental/postgres/cmd/query.go | 39 +++--- experimental/postgres/cmd/render_multi.go | 125 ++++++++---------- .../postgres/cmd/render_multi_test.go | 38 ++++-- experimental/postgres/cmd/result.go | 63 +++++---- 6 files changed, 157 insertions(+), 133 deletions(-) diff --git a/experimental/postgres/cmd/multistatement.go b/experimental/postgres/cmd/multistatement.go index 4c4f976e8e..4bfedbbfab 100644 --- a/experimental/postgres/cmd/multistatement.go +++ b/experimental/postgres/cmd/multistatement.go @@ -127,9 +127,12 @@ func scanBlockComment(s string, start, end int) int { // '$' of $tag$. If the construct doesn't look like a valid dollar-quote // opener, returns ("", start) so the caller can fall through. // -// Tag rule: starts after '$', runs to the next '$', and must consist of -// letter-or-underscore-or-digit (we accept all non-special bytes; over- -// permissive). Empty tag is valid: $$ is a marker, $$body$$ is the body. +// Tag rule: starts after '$', runs to the next '$'. Per the Postgres docs a +// dollar-quote tag must not start with a digit, so we reject `$1`, `$2`, +// etc. as tags and let the scanner treat them as ordinary bytes (this is +// what `$1`-style parameter placeholders look like, even though `QueryExecModeExec` +// can't bind them in this command). Empty tag is valid: $$ is a marker, +// $$body$$ is the body. func readDollarTag(s string, start, end int) (string, int) { i := start + 1 for i < end { @@ -137,6 +140,10 @@ func readDollarTag(s string, start, end int) (string, int) { tag := s[start+1 : i] return tag, i + 1 } + // Reject `$...` early: it can't be a valid tag. 
+ if i == start+1 && s[i] >= '0' && s[i] <= '9' { + return "", start + } // Stop at characters that can't be in a tag. if s[i] == ' ' || s[i] == '\t' || s[i] == '\n' || s[i] == ';' { return "", start diff --git a/experimental/postgres/cmd/multistatement_test.go b/experimental/postgres/cmd/multistatement_test.go index bb50bf5e8e..ae60ee7e15 100644 --- a/experimental/postgres/cmd/multistatement_test.go +++ b/experimental/postgres/cmd/multistatement_test.go @@ -39,6 +39,18 @@ func TestCheckSingleStatement(t *testing.T) { {name: "empty input", input: "", wantErr: false}, {name: "only whitespace", input: " \n\t ", wantErr: false}, {name: "only semicolon", input: ";", wantErr: false}, + + // $1 / $2 placeholder syntax must not be confused with a dollar-quote + // tag (tags can't start with a digit per PG docs). + {name: "dollar-digit placeholders", input: "SELECT $1, $2 FROM t", wantErr: false}, + {name: "dollar-digit then real semi", input: "SELECT $1 FROM t; SELECT 2", wantErr: true}, + + // E-string escape syntax: scanner doesn't honour \' escape, so a + // backslash-escaped apostrophe terminates the literal early. We + // document the over-rejection rather than fix it (acceptable v1 + // stance per the plan); pin the behaviour here so the next person + // touching the scanner has to update the test. 
+ {name: "E-string with backslash-escape over-rejects", input: `SELECT E'foo\';bar'`, wantErr: true}, } for _, tc := range tests { diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index bc339c29d4..4bd75c8a71 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -2,6 +2,7 @@ package postgrescmd import ( "context" + "errors" "fmt" "io" "time" @@ -128,7 +129,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla } for _, u := range units { if err := checkSingleStatement(u.SQL); err != nil { - return fmt.Errorf("%s: %w%s", u.Source, err, multiStatementHint()) + return fmt.Errorf("%s: %w%s", u.Source, err, multiStatementHint) } } @@ -183,14 +184,18 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla if len(units) == 1 { // Single-input path: stream directly through the per-format sink. // Avoids buffering rows for large exports and matches the v1 single- - // input behaviour PR 2 shipped. + // input behaviour PR 2 shipped. Wrap the error so DETAIL / HINT + // from a *pgconn.PgError surface even on the single-input path. sink := newSink(format, out, stderr) - return executeOne(ctx, conn, units[0].SQL, sink) + if err := executeOne(ctx, conn, units[0].SQL, sink); err != nil { + return errors.New(formatPgError(err)) + } + return nil } // Multi-input path: per-unit buffering. The plan accepts this trade-off // (multi-input invocations with huge SELECTs should use single-input - // invocations with --output csv for streaming). Sessions state (SET, + // invocations with --output csv for streaming). Session state (SET, // temp tables) carries across units because we hold the same connection. 
results := make([]*unitResult, 0, len(units)) for _, u := range units { @@ -198,7 +203,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla if err != nil { // Render the successful prefix, then surface the error with // rich pgError formatting if applicable. - if rerr := renderPartial(out, stderr, format, results, u, err); rerr != nil { + if rerr := renderPartial(out, stderr, format, results, r, err); rerr != nil { // Best-effort partial render failed; surface the original // error to the user, the renderer error to debug logs. fmt.Fprintln(stderr, "warning: failed to render partial result:", rerr) @@ -210,7 +215,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla switch format { case outputJSON: - return renderJSONMulti(out, stderr, results, -1, "") + return renderJSONMulti(out, stderr, results, nil, "") default: return renderTextMulti(out, results) } @@ -232,10 +237,10 @@ func newSink(format outputFormat, out, stderr io.Writer) rowSink { // renderPartial emits the rendered output for the prefix of units that ran // successfully before a unit errored. For multi-input json this also writes // the error envelope as the last array element. -func renderPartial(out, stderr io.Writer, format outputFormat, results []*unitResult, errored inputUnit, err error) error { +func renderPartial(out, stderr io.Writer, format outputFormat, results []*unitResult, errored *unitResult, err error) error { switch format { case outputJSON: - return renderJSONMulti(out, stderr, results, len(results), formatExecutionErrorMessage(errored.Source, err)) + return renderJSONMulti(out, stderr, results, errored, formatPgError(err)) default: // Text: render whatever ran cleanly. The error message goes through // cobra's default error path on stderr after we return. 
@@ -250,16 +255,8 @@ func formatExecutionError(source string, err error) error { return fmt.Errorf("%s: %s", source, formatPgError(err)) } -// formatExecutionErrorMessage is the string form of formatExecutionError, -// suitable for embedding in JSON envelopes. -func formatExecutionErrorMessage(source string, err error) string { - return fmt.Sprintf("%s: %s", source, formatPgError(err)) -} - -// multiStatementHint is the workaround pointer appended to the -// errMultipleStatements error so users see the recovery path inline. -func multiStatementHint() string { - return "\nThis command runs one statement per input. To run multiple statements:\n" + - ` - Pass each as a separate positional: query "SELECT 1" "SELECT 2"` + "\n" + - ` - Pass each in its own --file: query --file q1.sql --file q2.sql` -} +// multiStatementHint is appended to errMultipleStatements so users see the +// recovery path inline. +const multiStatementHint = "\nThis command runs one statement per input. To run multiple statements:\n" + + ` - Pass each as a separate positional: query "SELECT 1" "SELECT 2"` + "\n" + + ` - Pass each in its own --file: query --file q1.sql --file q2.sql` diff --git a/experimental/postgres/cmd/render_multi.go b/experimental/postgres/cmd/render_multi.go index 6ffc1e8d58..4cfa2063f7 100644 --- a/experimental/postgres/cmd/render_multi.go +++ b/experimental/postgres/cmd/render_multi.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "strings" ) // renderTextMulti renders a sequence of unit results as plain text. Each @@ -57,15 +56,19 @@ func renderTextResult(out io.Writer, r *unitResult) error { // streaming. CSV multi-input is rejected pre-flight, so this function is // only used for json. 
// -// Per-unit shape: +// Every per-unit object shares the same canonical key order: // -// {"sql": "...", "kind": "rows", "elapsed_ms": N, "rows": [...]} -// {"sql": "...", "kind": "command", "elapsed_ms": N, "command": "...", "rows_affected": N} -// {"sql": "...", "kind": "error", "elapsed_ms": N, "error": {...}} +// {"source", "sql", "kind", "elapsed_ms", payload...} // -// kind discriminates which fields are present so consumers don't have to -// branch on key presence. -func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errIndex int, errMessage string) error { +// where payload depends on kind: +// +// "rows": {..., "rows": [...]} +// "command": {..., "command": "...", "rows_affected": N} +// "error": {..., "error": {"message": "..."}} +// +// elapsed_ms is present on errors too: it captures how long the failing +// statement ran before the error fired. +func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errored *unitResult, errMessage string) error { if _, err := io.WriteString(out, "[\n"); err != nil { return err } @@ -85,21 +88,13 @@ func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errIndex int, } } - if errIndex >= 0 { - // The errored unit follows the last successful unit; write a comma - // separator and the error envelope for it. + if errored != nil { if len(results) > 0 { if _, err := io.WriteString(out, ",\n"); err != nil { return err } } - errSQL := "" - errSource := "" - // errIndex points to the input *unit* index; since we render - // successful units in order, the errored unit's SQL came from the - // caller's units slice. The caller embeds it in errMessage so we - // don't need separate plumbing here. 
- obj := jsonErrorObject(errSource, errSQL, errMessage) + obj := jsonErrorObject(errored, errMessage) if _, err := out.Write(obj); err != nil { return err } @@ -113,18 +108,19 @@ func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errIndex int, // existing single-input json rendering for the rows array so the value // mapping stays consistent across single- and multi-input shapes. func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { + if err := writeJSONUnitHeader(buf, r); err != nil { + return err + } + if !r.IsRowsProducing() { - // Command-only unit. - if _, err := fmt.Fprintf(buf, `{"sql":`); err != nil { - return err - } - sqlJSON, err := marshalJSON(r.SQL) + buf.WriteString(`,"kind":"command"`) + fmt.Fprintf(buf, `,"elapsed_ms":%d`, r.Elapsed.Milliseconds()) + verbBytes, err := marshalJSON(commandTagVerb(r.CommandTag)) if err != nil { return err } - buf.Write(sqlJSON) - fmt.Fprintf(buf, `,"kind":"command","elapsed_ms":%d`, r.Elapsed.Milliseconds()) - fmt.Fprintf(buf, `,"command":"%s"`, jsonEscapeShort(commandTagVerb(r.CommandTag))) + buf.WriteString(`,"command":`) + buf.Write(verbBytes) if rows, ok := commandTagRowCount(r.CommandTag); ok { fmt.Fprintf(buf, `,"rows_affected":%d`, rows) } @@ -132,17 +128,10 @@ func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { return nil } - // Rows-producing unit. We reuse jsonSink for the rows array body so - // the per-row encoding (column order, type mapping) stays in one place. - if _, err := fmt.Fprintf(buf, `{"sql":`); err != nil { - return err - } - sqlJSON, err := marshalJSON(r.SQL) - if err != nil { - return err - } - buf.Write(sqlJSON) - fmt.Fprintf(buf, `,"kind":"rows","elapsed_ms":%d,"rows":`, r.Elapsed.Milliseconds()) + // Rows-producing unit. Reuse jsonSink for the rows array body so the + // per-row encoding (column order, type mapping) stays in one place. 
+ buf.WriteString(`,"kind":"rows"`) + fmt.Fprintf(buf, `,"elapsed_ms":%d,"rows":`, r.Elapsed.Milliseconds()) rowsBuf := &bytes.Buffer{} sink := newJSONSink(rowsBuf, stderr) @@ -154,8 +143,6 @@ func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { return err } } - // Use a no-op tag for End so jsonSink's success path emits the closing - // bracket. The trailing newline gets trimmed below. if err := sink.End(""); err != nil { return err } @@ -165,24 +152,38 @@ func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { return nil } -// jsonErrorObject builds the per-unit error envelope used in the multi-input -// JSON shape. message is the formatted error message (already includes -// SQLSTATE / hint / detail when applicable). -func jsonErrorObject(source, sql, message string) []byte { - var buf bytes.Buffer - buf.WriteString(`{"source":`) - if b, err := marshalJSON(source); err == nil { - buf.Write(b) - } else { - buf.WriteString(`""`) +// writeJSONUnitHeader writes the canonical {source, sql, ...} prefix used +// by every per-unit object. The closing brace and the kind-specific payload +// are appended by the caller. +func writeJSONUnitHeader(buf *bytes.Buffer, r *unitResult) error { + sourceBytes, err := marshalJSON(r.Source) + if err != nil { + return err + } + sqlBytes, err := marshalJSON(r.SQL) + if err != nil { + return err } + buf.WriteString(`{"source":`) + buf.Write(sourceBytes) buf.WriteString(`,"sql":`) - if b, err := marshalJSON(sql); err == nil { - buf.Write(b) - } else { - buf.WriteString(`""`) + buf.Write(sqlBytes) + return nil +} + +// jsonErrorObject builds the per-unit error envelope used in the multi-input +// JSON shape. The buffered unitResult provides source, SQL, and the elapsed +// time captured by runUnitBuffered's error path. message is the +// already-formatted error wording (includes SQLSTATE / hint / detail for +// PgErrors). 
+func jsonErrorObject(r *unitResult, message string) []byte { + var buf bytes.Buffer + if err := writeJSONUnitHeader(&buf, r); err != nil { + return []byte(`{"source":"","sql":"","kind":"error","elapsed_ms":0,"error":{"message":""}}`) } - buf.WriteString(`,"kind":"error","error":{"message":`) + buf.WriteString(`,"kind":"error"`) + fmt.Fprintf(&buf, `,"elapsed_ms":%d`, r.Elapsed.Milliseconds()) + buf.WriteString(`,"error":{"message":`) if b, err := marshalJSON(message); err == nil { buf.Write(b) } else { @@ -191,19 +192,3 @@ func jsonErrorObject(source, sql, message string) []byte { buf.WriteString(`}}`) return buf.Bytes() } - -// jsonEscapeShort is a fast path for short ASCII strings (command tag verbs) -// that need backslash escapes for ", \, and control bytes only. Falls back -// to a string-escaped value if the input contains anything unusual. -func jsonEscapeShort(s string) string { - if !strings.ContainsAny(s, "\"\\\n\r\t") { - return s - } - out, err := marshalJSON(s) - if err != nil { - return s - } - // marshalJSON returns the value with surrounding quotes; strip them so - // the caller can wrap with its own quoting. 
- return string(bytes.Trim(out, `"`)) -} diff --git a/experimental/postgres/cmd/render_multi_test.go b/experimental/postgres/cmd/render_multi_test.go index dba5174a43..b4e96f73eb 100644 --- a/experimental/postgres/cmd/render_multi_test.go +++ b/experimental/postgres/cmd/render_multi_test.go @@ -53,16 +53,12 @@ func TestRenderJSONMulti_TwoResults(t *testing.T) { } var stdout, stderr bytes.Buffer - require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1, r2}, -1, "")) + require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1, r2}, nil, "")) out := stdout.String() - assert.Contains(t, out, `"sql":"INSERT INTO t VALUES (1)"`) - assert.Contains(t, out, `"kind":"command"`) - assert.Contains(t, out, `"command":"INSERT"`) - assert.Contains(t, out, `"rows_affected":1`) - assert.Contains(t, out, `"sql":"SELECT id FROM t"`) - assert.Contains(t, out, `"kind":"rows"`) - assert.Contains(t, out, `"rows":`) + // Canonical key order: source, sql, kind, elapsed_ms, payload. + assert.Contains(t, out, `"source":"argv[1]","sql":"INSERT INTO t VALUES (1)","kind":"command","elapsed_ms":5,"command":"INSERT","rows_affected":1`) + assert.Contains(t, out, `"source":"argv[2]","sql":"SELECT id FROM t","kind":"rows","elapsed_ms":3,"rows":`) // Outer array framing. 
assert.Greater(t, len(out), 4) assert.Equal(t, byte('['), out[0]) @@ -78,12 +74,32 @@ func TestRenderJSONMulti_WithErrorAtEnd(t *testing.T) { CommandTag: "SELECT 1", Elapsed: 1 * time.Millisecond, } + errored := &unitResult{ + Source: "argv[2]", + SQL: "BROKEN SQL", + Elapsed: 2 * time.Millisecond, + } var stdout, stderr bytes.Buffer - require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1}, 1, "argv[2]: ERROR: syntax error (SQLSTATE 42601)")) + require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1}, errored, "ERROR: syntax error (SQLSTATE 42601)")) out := stdout.String() assert.Contains(t, out, `"kind":"rows"`) - assert.Contains(t, out, `"kind":"error"`) - assert.Contains(t, out, `"message":"argv[2]: ERROR: syntax error (SQLSTATE 42601)"`) + // Error envelope: same key order, includes elapsed_ms + source + sql. + assert.Contains(t, out, `"source":"argv[2]","sql":"BROKEN SQL","kind":"error","elapsed_ms":2,"error":{"message":"ERROR: syntax error (SQLSTATE 42601)"}`) +} + +func TestRenderJSONMulti_FirstUnitFails(t *testing.T) { + errored := &unitResult{ + Source: "argv[1]", + SQL: "BROKEN", + Elapsed: 7 * time.Millisecond, + } + var stdout, stderr bytes.Buffer + require.NoError(t, renderJSONMulti(&stdout, &stderr, nil, errored, "ERROR: bad")) + + out := stdout.String() + // No leading separator before the single error envelope. 
+ assert.Contains(t, out, "[\n"+`{"source":"argv[1]","sql":"BROKEN","kind":"error","elapsed_ms":7,"error":{"message":"ERROR: bad"}}`) + assert.Contains(t, out, "\n]\n") } diff --git a/experimental/postgres/cmd/result.go b/experimental/postgres/cmd/result.go index d9b449a484..ec03534bfb 100644 --- a/experimental/postgres/cmd/result.go +++ b/experimental/postgres/cmd/result.go @@ -2,7 +2,6 @@ package postgrescmd import ( "context" - "fmt" "time" "github.com/jackc/pgx/v5" @@ -27,36 +26,44 @@ func (r *unitResult) IsRowsProducing() bool { return len(r.Fields) > 0 } -// runUnitBuffered runs sql and collects every row into memory. Used by the -// multi-input output paths (text and json), where per-unit buffering is -// acceptable per the plan: a multi-input invocation that emits a huge -// SELECT will buffer that result before printing. Users with huge result -// sets per statement should use single-input invocations (which fully -// stream) or --output csv on a single input. +// runUnitBuffered runs sql and collects every row into memory. Implemented +// as a thin wrapper that hands a bufferSink to executeOne, so error wrapping +// and the rowSink contract stay in one place rather than parallel-evolving +// across two query loops. func runUnitBuffered(ctx context.Context, conn *pgx.Conn, unit inputUnit) (*unitResult, error) { start := time.Now() - rows, err := conn.Query(ctx, unit.SQL, pgx.QueryExecModeExec) - if err != nil { - return nil, fmt.Errorf("query failed: %w", err) + r := &unitResult{Source: unit.Source, SQL: unit.SQL} + sink := &bufferSink{result: r} + if err := executeOne(ctx, conn, unit.SQL, sink); err != nil { + // Capture timing on the error path too so the JSON error envelope + // can surface "this query ran for X seconds before failing". 
+ r.Elapsed = time.Since(start) + return r, err } - defer rows.Close() - - r := &unitResult{ - Source: unit.Source, - SQL: unit.SQL, - Fields: rows.FieldDescriptions(), - } - for rows.Next() { - values, err := rows.Values() - if err != nil { - return nil, fmt.Errorf("decode row: %w", err) - } - r.Rows = append(r.Rows, values) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("query failed: %w", err) - } - r.CommandTag = rows.CommandTag().String() r.Elapsed = time.Since(start) return r, nil } + +// bufferSink is a rowSink that copies fields, rows, and the command tag into +// a unitResult instead of writing anywhere. Used by the multi-input path so +// successive units can be rendered together once they're all available. +type bufferSink struct { + result *unitResult +} + +func (s *bufferSink) Begin(fields []pgconn.FieldDescription) error { + s.result.Fields = fields + return nil +} + +func (s *bufferSink) Row(values []any) error { + s.result.Rows = append(s.result.Rows, values) + return nil +} + +func (s *bufferSink) End(commandTag string) error { + s.result.CommandTag = commandTag + return nil +} + +func (s *bufferSink) OnError(err error) {} From 5f193b40dff126619cfa5bc14467f8fc4e7fcbc9 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:47:23 +0200 Subject: [PATCH 3/3] PR 3 r2: drop unreachable json-encoding fallback branches Round-2 reviewer noted jsonErrorObject's defensive branches around writeJSONUnitHeader/marshalJSON are unreachable (encoding/json doesn't error on string inputs), and the repo rule says drop "just in case" fallbacks. Replace with panic-on-impossible helpers. 
Co-authored-by: Isaac --- experimental/postgres/cmd/render_multi.go | 32 +++++++++++++++++------ 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/experimental/postgres/cmd/render_multi.go b/experimental/postgres/cmd/render_multi.go index 4cfa2063f7..2a2d793816 100644 --- a/experimental/postgres/cmd/render_multi.go +++ b/experimental/postgres/cmd/render_multi.go @@ -176,19 +176,35 @@ func writeJSONUnitHeader(buf *bytes.Buffer, r *unitResult) error { // time captured by runUnitBuffered's error path. message is the // already-formatted error wording (includes SQLSTATE / hint / detail for // PgErrors). +// +// marshalJSON of a string never errors (encoding/json replaces invalid UTF-8 +// with U+FFFD), so the inner errors are unreachable and we treat them as +// programming errors via panic. func jsonErrorObject(r *unitResult, message string) []byte { var buf bytes.Buffer - if err := writeJSONUnitHeader(&buf, r); err != nil { - return []byte(`{"source":"","sql":"","kind":"error","elapsed_ms":0,"error":{"message":""}}`) - } + mustWriteJSONHeader(&buf, r) buf.WriteString(`,"kind":"error"`) fmt.Fprintf(&buf, `,"elapsed_ms":%d`, r.Elapsed.Milliseconds()) buf.WriteString(`,"error":{"message":`) - if b, err := marshalJSON(message); err == nil { - buf.Write(b) - } else { - buf.WriteString(`""`) - } + buf.Write(mustMarshalJSON(message)) buf.WriteString(`}}`) return buf.Bytes() } + +// mustWriteJSONHeader is writeJSONUnitHeader with a panic instead of an +// error return. The only failure mode is an unreachable encoding/json error. +func mustWriteJSONHeader(buf *bytes.Buffer, r *unitResult) { + if err := writeJSONUnitHeader(buf, r); err != nil { + panic(fmt.Errorf("encoding json header: %w", err)) + } +} + +// mustMarshalJSON is marshalJSON with a panic instead of an error return, +// for the same reason. +func mustMarshalJSON(v any) []byte { + b, err := marshalJSON(v) + if err != nil { + panic(fmt.Errorf("encoding json value: %w", err)) + } + return b +}