diff --git a/cursor.go b/cursor.go new file mode 100644 index 0000000..749764e --- /dev/null +++ b/cursor.go @@ -0,0 +1,153 @@ +package pgkit + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + + sq "github.com/Masterminds/squirrel" + "github.com/lann/builder" +) + +var ( + // ErrInvalidCursor signals a client-supplied cursor that failed to decode - map to 400, not 500. + ErrInvalidCursor = errors.New("invalid cursor") + // ErrCursorQueryOrdered signals a cursor-paginated query that already had ORDER BY. + ErrCursorQueryOrdered = errors.New("cursor query already has order by") + // ErrCursorPageOrdered signals page-level ordering that does not match the cursor order. + ErrCursorPageOrdered = errors.New("cursor page order does not match cursor order") +) + +// EncodeCursor produces an opaque cursor: base64-JSON, not signed, never use it for authorization. +func EncodeCursor[C any](cursor C) (string, error) { + raw, err := json.Marshal(cursor) + if err != nil { + return "", fmt.Errorf("marshal cursor: %w", err) + } + return base64.RawURLEncoding.EncodeToString(raw), nil +} + +// DecodeCursor returns (nil, nil) for empty input so callers can compose with a nil-check. +func DecodeCursor[C any](value string) (*C, error) { + if value == "" { + return nil, nil + } + raw, err := base64.RawURLEncoding.DecodeString(value) + if err != nil { + return nil, ErrInvalidCursor + } + var cursor C + if err := json.Unmarshal(raw, &cursor); err != nil { + return nil, ErrInvalidCursor + } + return &cursor, nil +} + +// Cursor is the interface a typed keyset cursor satisfies — mirrors pgkit.Record[T, I]'s self-pointer pattern. +type Cursor[Self any, Row any] interface { + PtrTo[Self] + Apply(sq.SelectBuilder) sq.SelectBuilder + From(Row) error + // OrderBy must match Apply and should include a unique tiebreaker. + OrderBy() []Sort +} + +// CursorPaginator is the keyset sibling of Paginator[T] for ordering-stable pagination under concurrent writes. +type CursorPaginator[T any, C any, PC Cursor[C, T]] struct { + settings PaginatorSettings +} + +// NewCursorPaginator honors only size options - the cursor owns ORDER BY. +func NewCursorPaginator[T any, C any, PC Cursor[C, T]](options ...PaginatorOption) CursorPaginator[T, C, PC] { + settings := &PaginatorSettings{ + DefaultSize: DefaultPageSize, + MaxSize: MaxPageSize, + } + for _, option := range options { + option(settings) + } + if settings.MaxSize < settings.DefaultSize { + settings.MaxSize = settings.DefaultSize + } + return CursorPaginator[T, C, PC]{settings: *settings} +} + +// PrepareQuery chains LIMIT n+1 so PrepareResult can detect a next page without a second round-trip. +func (p CursorPaginator[T, C, PC]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.SelectBuilder, error) { + if page == nil { + page = &Page{} + } + page.SetDefaults(&p.settings) + + if _, ok := builder.Get(q, "OrderByParts"); ok { + return nil, q, ErrCursorQueryOrdered + } + var zero C + order := PC(&zero).OrderBy() + pageOrder := page.GetOrder(nil) + if len(pageOrder) != 0 && len(pageOrder) != len(order) { + return nil, q, ErrCursorPageOrdered + } + for i := range pageOrder { + if pageOrder[i] != order[i].sanitize(nil) { + return nil, q, ErrCursorPageOrdered + } + } + for _, sort := range order { + q = q.OrderBy(sort.String()) + } + if page.Cursor != "" { + cursor, err := DecodeCursor[C](page.Cursor) + if err != nil { + return nil, q, err + } + q = PC(cursor).Apply(q) + } + + limit := page.Limit() + q = q.Limit(limit + 1) + return make([]T, 0, limit+1), q, nil +} + +// Paginate returns cursor-paginated rows and the page populated with More and NextCursor. +func (p CursorPaginator[T, C, PC]) Paginate(ctx context.Context, query *Querier, q sq.SelectBuilder, page *Page) ([]T, *Page, error) { + if page == nil { + page = &Page{} + } + result, q, err := p.PrepareQuery(q, page) + if err != nil { + return nil, nil, err + } + if err := query.GetAll(ctx, q, &result); err != nil { + return nil, nil, err + } + result, err = p.PrepareResult(result, page) + if err != nil { + return nil, nil, err + } + return result, page, nil +} + +// PrepareResult must be called after GetAll to populate page.More and page.NextCursor. +func (p CursorPaginator[T, C, PC]) PrepareResult(result []T, page *Page) ([]T, error) { + limit := int(page.Limit()) + page.Size = uint32(limit) + page.More = len(result) > limit + if !page.More { + return result, nil + } + result = result[:limit] + + var cursor C + if err := PC(&cursor).From(result[len(result)-1]); err != nil { + return nil, fmt.Errorf("cursor from row: %w", err) + } + next, err := EncodeCursor(cursor) + if err != nil { + return nil, err + } + page.NextCursor = next + return result, nil +} diff --git a/cursor_test.go b/cursor_test.go new file mode 100644 index 0000000..a27571b --- /dev/null +++ b/cursor_test.go @@ -0,0 +1,328 @@ +package pgkit_test + +import ( + "errors" + "strconv" + "strings" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/goware/pgkit/v2" + "github.com/stretchr/testify/require" +) + +type row struct { + ID string +} + +type rowCursor struct { + ID string `json:"id"` +} + +func (c *rowCursor) Apply(q sq.SelectBuilder) sq.SelectBuilder { + return q.Where(sq.Lt{"id": c.ID}) +} + +func (c *rowCursor) OrderBy() []pgkit.Sort { + return []pgkit.Sort{{Column: "id", Order: pgkit.Desc}} +} + +func (c *rowCursor) From(r row) error { + c.ID = r.ID + return nil +} + +func TestEncodeDecodeCursorRoundTrip(t *testing.T) { + encoded, err := pgkit.EncodeCursor(rowCursor{ID: "row_1"}) + require.NoError(t, err) + require.NotEmpty(t, encoded) + + decoded, err := pgkit.DecodeCursor[rowCursor](encoded) + require.NoError(t, err) + require.NotNil(t, decoded) + require.Equal(t, "row_1", decoded.ID) +} + +func TestDecodeCursorEmptyReturnsNil(t *testing.T) { + decoded, err := pgkit.DecodeCursor[rowCursor]("") + require.NoError(t, err) + require.Nil(t, decoded) +} + +func TestDecodeCursorInvalidBase64(t *testing.T) { + _, err := pgkit.DecodeCursor[rowCursor]("!!!not-base64!!!") + require.Error(t, err) + require.True(t, errors.Is(err, pgkit.ErrInvalidCursor)) +} + +func TestDecodeCursorInvalidJSON(t *testing.T) { + encoded, err := pgkit.EncodeCursor("not a struct") + require.NoError(t, err) + + _, err = pgkit.DecodeCursor[rowCursor](encoded) + require.Error(t, err) + require.True(t, errors.Is(err, pgkit.ErrInvalidCursor)) +} + +func TestCursorPaginatorFirstPage(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]( + pgkit.WithDefaultSize(2), + pgkit.WithMaxSize(5), + ) + page := &pgkit.Page{} + + result, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + require.Len(t, result, 0) + require.Equal(t, 3, cap(result)) + + sql, args, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 3`, sql) + require.Empty(t, args) +} + +func TestCursorPaginatorWithCursor(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2)) + encoded, err := pgkit.EncodeCursor(rowCursor{ID: "row_5"}) + require.NoError(t, err) + page := &pgkit.Page{Cursor: encoded} + + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + sql, args, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t WHERE id < ? ORDER BY "id" DESC LIMIT 3`, sql) + require.Equal(t, []any{"row_5"}, args) +} + +func TestCursorPaginatorRejectsPreorderedQuery(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]() + + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t").OrderBy("name"), &pgkit.Page{}) + require.ErrorIs(t, err, pgkit.ErrCursorQueryOrdered) +} + +func TestCursorPaginatorAllowsMatchingPageOrder(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]() + + tests := []struct { + name string + page *pgkit.Page + }{ + { + name: "sort", + page: &pgkit.Page{Sort: []pgkit.Sort{{Column: "id", Order: pgkit.Desc}}}, + }, + { + name: "column", + page: &pgkit.Page{Column: "-id"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), tt.page) + require.NoError(t, err) + + sql, args, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 11`, sql) + require.Empty(t, args) + }) + } +} + +func TestCursorPaginatorRejectsMismatchedPageOrder(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]() + + tests := []struct { + name string + page *pgkit.Page + }{ + { + name: "sort", + page: &pgkit.Page{Sort: []pgkit.Sort{{Column: "name", Order: pgkit.Asc}}}, + }, + { + name: "column", + page: &pgkit.Page{Column: "name"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), tt.page) + require.ErrorIs(t, err, pgkit.ErrCursorPageOrdered) + }) + } +} + +func TestCursorPaginatorInvalidCursor(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]() + page := &pgkit.Page{Cursor: "!!!not-base64!!!"} + + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.Error(t, err) + require.True(t, errors.Is(err, pgkit.ErrInvalidCursor)) +} + +func TestCursorPaginatorPrepareResultNoMore(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(3)) + page := &pgkit.Page{} + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + result, err := paginator.PrepareResult([]row{{ID: "1"}, {ID: "2"}}, page) + require.NoError(t, err) + require.Len(t, result, 2) + require.False(t, page.More) + require.Empty(t, page.NextCursor) + require.Equal(t, uint32(3), page.Size) +} + +func TestCursorPaginatorPrepareResultHasMore(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2)) + page := &pgkit.Page{} + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + result, err := paginator.PrepareResult( + []row{{ID: "3"}, {ID: "2"}, {ID: "1"}}, + page, + ) + require.NoError(t, err) + require.Equal(t, []row{{ID: "3"}, {ID: "2"}}, result) + require.True(t, page.More) + require.NotEmpty(t, page.NextCursor) + + decoded, err := pgkit.DecodeCursor[rowCursor](page.NextCursor) + require.NoError(t, err) + require.NotNil(t, decoded) + require.Equal(t, "2", decoded.ID) +} + +func TestCursorPaginatorDefaultsFromNilPage(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]() + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), nil) + require.NoError(t, err) + + sql, _, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 11`, sql) +} + +func TestCursorPaginatorCapsAtMaxSize(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]( + pgkit.WithDefaultSize(5), + pgkit.WithMaxSize(10), + ) + page := &pgkit.Page{Size: 999} + + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + sql, _, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 11`, sql) + require.Equal(t, uint32(10), page.Size) +} + +func TestCursorPaginatorMaxSizeBelowDefaultIsLifted(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor]( + pgkit.WithDefaultSize(20), + pgkit.WithMaxSize(5), + ) + page := &pgkit.Page{} + + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + sql, _, err := q.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 21`, sql) +} + +func TestCursorPaginatorWalksPages(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, rowCursor, *rowCursor](pgkit.WithDefaultSize(2)) + all := []row{{ID: "5"}, {ID: "4"}, {ID: "3"}, {ID: "2"}, {ID: "1"}} + + var ( + page = &pgkit.Page{} + seen []row + ) + for step := 0; step < 5; step++ { + _, q, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + fetched := fetch(t, all, q) + got, err := paginator.PrepareResult(fetched, page) + require.NoError(t, err) + + seen = append(seen, got...) + if !page.More { + break + } + page.Cursor = page.NextCursor + page.NextCursor = "" + } + require.Equal(t, all, seen) + require.False(t, page.More) +} + +type failingRowCursor struct { + ID string `json:"id"` +} + +func (c *failingRowCursor) Apply(q sq.SelectBuilder) sq.SelectBuilder { + return q.Where(sq.Lt{"id": c.ID}) +} + +func (c *failingRowCursor) OrderBy() []pgkit.Sort { + return []pgkit.Sort{{Column: "id", Order: pgkit.Desc}} +} + +var errBoom = errors.New("boom") + +func (c *failingRowCursor) From(row) error { + return errBoom +} + +func TestCursorPaginatorPrepareResultPropagatesCursorError(t *testing.T) { + paginator := pgkit.NewCursorPaginator[row, failingRowCursor, *failingRowCursor](pgkit.WithDefaultSize(1)) + page := &pgkit.Page{} + _, _, err := paginator.PrepareQuery(sq.Select("*").From("t"), page) + require.NoError(t, err) + + _, err = paginator.PrepareResult([]row{{ID: "2"}, {ID: "1"}}, page) + require.Error(t, err) + require.True(t, errors.Is(err, errBoom)) +} + +// In-memory stand-in so the pagination walk exercises encode/decode without a real database. +func fetch(t *testing.T, all []row, q sq.SelectBuilder) []row { + t.Helper() + sql, args, err := q.ToSql() + require.NoError(t, err) + + limit, err := strconv.Atoi(sql[strings.LastIndex(sql, " ")+1:]) + require.NoError(t, err) + + cutoff := "" + if len(args) == 1 { + cutoff = args[0].(string) + } + + out := make([]row, 0, limit) + for _, r := range all { + if cutoff != "" && r.ID >= cutoff { + continue + } + out = append(out, r) + if len(out) == limit { + break + } + } + return out +} diff --git a/go.mod b/go.mod index 7d1e982..9f9ce1c 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/georgysavva/scany/v2 v2.1.4 github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 github.com/jackc/pgx/v5 v5.9.0 + github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 github.com/stretchr/testify v1.11.1 ) @@ -16,7 +17,6 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/kr/text v0.2.0 // indirect - github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect github.com/lib/pq v1.10.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/page.go b/page.go index 35621e7..3b99f55 100644 --- a/page.go +++ b/page.go @@ -1,6 +1,7 @@ package pgkit import ( + "context" "fmt" "regexp" "slices" @@ -81,6 +82,10 @@ type Page struct { More bool Column string Sort []Sort + + // Unused by the offset Paginator — shared here so callers can swap paginators without changing the Page type. + Cursor string + NextCursor string } func NewPage(size, page uint32, sort ...Sort) *Page { @@ -251,6 +256,18 @@ func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.Sele return make([]T, 0, limit+1), q } +// Paginate returns offset-paginated rows and the page populated with More. +func (p Paginator[T]) Paginate(ctx context.Context, query *Querier, q sq.SelectBuilder, page *Page) ([]T, *Page, error) { + if page == nil { + page = &Page{} + } + result, q := p.PrepareQuery(q, page) + if err := query.GetAll(ctx, q, &result); err != nil { + return nil, nil, err + } + return p.PrepareResult(result, page), page, nil +} + func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, []any) { if page == nil { page = &Page{} diff --git a/table.go b/table.go index b185f46..f2029a3 100644 --- a/table.go +++ b/table.go @@ -15,9 +15,14 @@ import ( // ID is a comparable type used for record IDs. type ID comparable +// PtrTo constrains a type parameter to be a pointer to T. +type PtrTo[T any] interface { + *T +} + // Records must be a pointer with the methods defined on the pointer. type Record[T any, I ID] interface { - *T // Enforce T is a pointer. + PtrTo[T] GetID() I Validate() error } diff --git a/tests/cursor_test.go b/tests/cursor_test.go new file mode 100644 index 0000000..14f531c --- /dev/null +++ b/tests/cursor_test.go @@ -0,0 +1,119 @@ +package pgkit_test + +import ( + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/goware/pgkit/v2" + "github.com/stretchr/testify/require" +) + +type articleCursor struct { + ID uint64 `json:"id"` +} + +func (c *articleCursor) Apply(q sq.SelectBuilder) sq.SelectBuilder { + return q.Where(sq.Lt{"id": c.ID}) +} + +func (c *articleCursor) OrderBy() []pgkit.Sort { + return []pgkit.Sort{{Column: "id", Order: pgkit.Desc}} +} + +func (c *articleCursor) From(article *Article) error { + c.ID = article.ID + return nil +} + +func TestCursorPaginatorPaginateReturnsPage(t *testing.T) { + ctx := t.Context() + db := initDB(DB) + + account := &Account{Name: "CursorPaginatorPaginate Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + for range 5 { + err := db.Articles.Save(ctx, &Article{ + AccountID: account.ID, + Author: "Cursor Author", + }) + require.NoError(t, err) + } + + paginator := pgkit.NewCursorPaginator[*Article, articleCursor, *articleCursor]( + pgkit.WithDefaultSize(2), + ) + q := db.SQL.Select("*"). + From("articles"). + Where(sq.Eq{"account_id": account.ID}) + + first, firstPage, err := paginator.Paginate(ctx, db.Query, q, nil) + require.NoError(t, err) + require.Len(t, first, 2) + require.NotNil(t, firstPage) + require.Equal(t, uint32(2), firstPage.Size) + require.True(t, firstPage.More) + require.NotEmpty(t, firstPage.NextCursor) + + page := &pgkit.Page{ + Cursor: firstPage.NextCursor, + } + second, secondPage, err := paginator.Paginate(ctx, db.Query, q, page) + require.NoError(t, err) + require.Len(t, second, 2) + require.Same(t, page, secondPage) + require.True(t, secondPage.More) + require.NotEmpty(t, secondPage.NextCursor) + + for _, a := range first { + for _, b := range second { + require.NotEqual(t, a.ID, b.ID, "cursor pages should not overlap") + } + } +} + +func TestPaginatorPaginateReturnsPage(t *testing.T) { + ctx := t.Context() + db := initDB(DB) + + account := &Account{Name: "PaginatorPaginate Account"} + err := db.Accounts.Save(ctx, account) + require.NoError(t, err) + + for range 5 { + err := db.Articles.Save(ctx, &Article{ + AccountID: account.ID, + Author: "Offset Author", + }) + require.NoError(t, err) + } + + paginator := pgkit.NewPaginator[*Article](pgkit.WithDefaultSize(2)) + q := db.SQL.Select("*"). + From("articles"). + Where(sq.Eq{"account_id": account.ID}) + + first, firstPage, err := paginator.Paginate(ctx, db.Query, q, nil) + require.NoError(t, err) + require.Len(t, first, 2) + require.NotNil(t, firstPage) + require.Equal(t, uint32(2), firstPage.Size) + require.Equal(t, uint32(1), firstPage.Page) + require.True(t, firstPage.More) + + page := &pgkit.Page{Page: 2} + second, secondPage, err := paginator.Paginate(ctx, db.Query, q, page) + require.NoError(t, err) + require.Len(t, second, 2) + require.Same(t, page, secondPage) + require.Equal(t, uint32(2), secondPage.Size) + require.Equal(t, uint32(2), secondPage.Page) + require.True(t, secondPage.More) + + for _, a := range first { + for _, b := range second { + require.NotEqual(t, a.ID, b.ID, "offset pages should not overlap") + } + } +}