From ab91c423c0b54dd74894c6f720d8b8229b5048c5 Mon Sep 17 00:00:00 2001 From: Rohan-R07 Date: Sat, 13 Jun 2026 10:34:59 +0530 Subject: [PATCH 1/2] feat: implement automated database replication and fix RLS subquery & test issues --- README.md | 17 ++ dist/plugins.ts | 1 + plugins/replication/index.test.ts | 284 +++++++++++++++++ plugins/replication/index.ts | 485 ++++++++++++++++++++++++++++++ src/index.ts | 8 + src/rls/index.test.ts | 27 +- src/rls/index.ts | 90 ++++-- wrangler.toml | 8 +- 8 files changed, 881 insertions(+), 39 deletions(-) create mode 100644 plugins/replication/index.test.ts create mode 100644 plugins/replication/index.ts diff --git a/README.md b/README.md index 5931b1c..949f9d8 100644 --- a/README.md +++ b/README.md @@ -270,6 +270,23 @@ curl --location 'https://starbasedb.YOUR-ID-HERE.workers.dev/import/dump' \ +

Database Replication

+

+ StarbaseDB includes an automated data replication plugin to synchronize data from external databases (such as PostgreSQL or MySQL) into StarbaseDB's internal SQLite. It runs as a scheduled background job, utilizing Cloudflare Durable Object Alarms. +

+

+ Configure the following variables in your wrangler.toml to enable replication: +

+
+
+[vars]
+EXTERNAL_DB_TYPE = "postgresql"
+EXTERNAL_DB_TABLES_TO_TRACK = "users,posts,comments" # Comma-separated list of tables to track
+EXTERNAL_DB_POLLING_INTERVAL = "*/1 * * * *"         # Polling interval (CRON expression)
+EXTERNAL_DB_BATCH_SIZE = 500                         # Number of rows per batch
+
+
+

Contributing

We welcome contributions! Please refer to our Contribution Guide for more details.

diff --git a/dist/plugins.ts b/dist/plugins.ts index 7dd252a..0a5cac5 100644 --- a/dist/plugins.ts +++ b/dist/plugins.ts @@ -6,3 +6,4 @@ export { ChangeDataCapturePlugin } from '../plugins/cdc' export { QueryLogPlugin } from '../plugins/query-log' export { ResendPlugin } from '../plugins/resend' export { ClerkPlugin } from '../plugins/clerk' +export { ReplicationPlugin } from '../plugins/replication' diff --git a/plugins/replication/index.test.ts b/plugins/replication/index.test.ts new file mode 100644 index 0000000..6bea53e --- /dev/null +++ b/plugins/replication/index.test.ts @@ -0,0 +1,284 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { ReplicationPlugin } from './index' +import { executeSDKQuery } from '../../src/operation' + +// Mock the operation module to mock executeSDKQuery +vi.mock('../../src/operation', () => ({ + executeSDKQuery: vi.fn(), +})) + +describe('ReplicationPlugin', () => { + let plugin: ReplicationPlugin + let mockCronPlugin: any + let mockDataSource: any + let mockConfig: any + + beforeEach(() => { + vi.clearAllMocks() + + mockCronPlugin = { + onEvent: vi.fn(), + addEvent: vi.fn(), + } + + mockDataSource = { + external: { + dialect: 'postgresql', + defaultSchema: 'public', + database: 'testdb', + }, + rpc: { + executeQuery: vi.fn().mockResolvedValue([]), + executeTransaction: vi.fn().mockResolvedValue([]), + setAlarm: vi.fn().mockResolvedValue(undefined), + }, + } + + mockConfig = { + role: 'client', + } + + plugin = new ReplicationPlugin({ cronPlugin: mockCronPlugin }) + plugin['dataSource'] = mockDataSource + plugin['config'] = mockConfig + plugin['env'] = { + EXTERNAL_DB_TYPE: 'postgresql', + EXTERNAL_DB_TABLES_TO_TRACK: 'users', + EXTERNAL_DB_BATCH_SIZE: 5, + } + }) + + it('should initialize correctly and register event callback', async () => { + const mockApp = { + use: vi.fn(), + } as any + + await plugin.register(mockApp) + + expect(mockApp.use).toHaveBeenCalledTimes(1) + expect(mockCronPlugin.onEvent).toHaveBeenCalledTimes(1) + }) + + it('should not run replication if already running', async () => { + plugin['isRunning'] = true + const logSpy = vi.spyOn(console, 'log').mockImplementation(() => {}) + + await plugin.runReplication() + + expect(logSpy).toHaveBeenCalledWith( + 'Database replication is already running. Skipping.' + ) + logSpy.mockRestore() + }) + + it('should map postgres schema and create local table if it does not exist', async () => { + // Mock query returned columns: id (INT), name (TEXT), updated_at (TIMESTAMP) + vi.mocked(executeSDKQuery).mockResolvedValue([ + { column_name: 'id', data_type: 'integer', is_nullable: 'NO' }, + { column_name: 'name', data_type: 'text', is_nullable: 'YES' }, + { + column_name: 'updated_at', + data_type: 'timestamp', + is_nullable: 'YES', + }, + ] as any) + + mockDataSource.rpc.executeQuery.mockResolvedValue([]) + + const moreData = await plugin['replicateTable']('users', 5) + + expect(moreData).toBe(false) + expect(executeSDKQuery).toHaveBeenCalledWith( + expect.objectContaining({ + sql: expect.stringContaining('information_schema.columns'), + }) + ) + expect(mockDataSource.rpc.executeQuery).toHaveBeenCalledWith( + expect.objectContaining({ + sql: expect.stringContaining( + 'CREATE TABLE IF NOT EXISTS "users"' + ), + }) + ) + }) + + it('should perform incremental polling using updated_at and id tie-breaker', async () => { + // Schema lookup mock + vi.mocked(executeSDKQuery).mockImplementation(async (opts: any) => { + if (opts.sql.includes('information_schema.columns')) { + return [ + { + column_name: 'id', + data_type: 'integer', + is_nullable: 'NO', + }, + { + column_name: 'name', + data_type: 'text', + is_nullable: 'YES', + }, + { + column_name: 'updated_at', + data_type: 'timestamp', + is_nullable: 'YES', + }, + ] + } + if (opts.sql.includes('SELECT * FROM')) { + // Mock return rows + return [ + { + id: 1, + name: 'Alice', + updated_at: '2026-06-13T10:00:00.000Z', + }, + { + id: 2, + name: 'Bob', + updated_at: '2026-06-13T10:05:00.000Z', + }, + ] + } + return [] + }) + + // Mock state retrieval (already has a watermark) + mockDataSource.rpc.executeQuery.mockImplementation( + async (opts: any) => { + if ( + opts.sql.includes('SELECT last_synced_id, last_synced_at') + ) { + return [ + { + last_synced_id: 1, + last_synced_at: '2026-06-13T09:00:00.000Z', + }, + ] + } + return [] + } + ) + + const moreData = await plugin['replicateTable']('users', 5) + + // Since rows count (2) is less than batchSize (5), it should be false (no more data) + expect(moreData).toBe(false) + + // Verify it queries with the correct parameters + expect(executeSDKQuery).toHaveBeenCalledWith( + expect.objectContaining({ + params: [ + '2026-06-13T09:00:00.000Z', + '2026-06-13T09:00:00.000Z', + 1, + 5, + ], + }) + ) + + // Verify transaction insert was run + expect(mockDataSource.rpc.executeTransaction).toHaveBeenCalledTimes(1) + expect( + mockDataSource.rpc.executeTransaction.mock.calls[0][0] + ).toHaveLength(2) + + // Verify replication state was updated + expect(mockDataSource.rpc.executeQuery).toHaveBeenCalledWith( + expect.objectContaining({ + sql: expect.stringContaining( + 'INSERT OR REPLACE INTO tmp_replication_state' + ), + params: ['users', '2', '2026-06-13T10:05:00.000Z'], + }) + ) + }) + + it('should map values correctly to SQLite compatible types', async () => { + // Mock columns + vi.mocked(executeSDKQuery).mockImplementation(async (opts: any) => { + if (opts.sql.includes('information_schema.columns')) { + return [ + { + column_name: 'id', + data_type: 'integer', + is_nullable: 'NO', + }, + { + column_name: 'metadata', + data_type: 'jsonb', + is_nullable: 'YES', + }, + { + column_name: 'is_active', + data_type: 'boolean', + is_nullable: 'YES', + }, + ] + } + if (opts.sql.includes('SELECT * FROM')) { + return [{ id: 1, metadata: { role: 'admin' }, is_active: true }] + } + return [] + }) + + mockDataSource.rpc.executeQuery.mockResolvedValue([]) + + await plugin['replicateTable']('users', 5) + + // Verify transaction mapped the object metadata and boolean is_active + expect(mockDataSource.rpc.executeTransaction).toHaveBeenCalledTimes(1) + const queries = mockDataSource.rpc.executeTransaction.mock.calls[0][0] + expect(queries[0].params).toEqual([1, '{"role":"admin"}', 1]) + }) + + it('should sync hard deletions', async () => { + // Mock columns + vi.mocked(executeSDKQuery).mockImplementation(async (opts: any) => { + if (opts.sql.includes('information_schema.columns')) { + return [ + { + column_name: 'id', + data_type: 'integer', + is_nullable: 'NO', + }, + ] + } + // First select rows: return empty to trigger deletion sync + if (opts.sql.includes('SELECT * FROM')) { + return [] + } + // ID query from external: returns IDs 1 and 2 + if (opts.sql.includes('SELECT "id" FROM')) { + return [{ id: 1 }, { id: 2 }] + } + return [] + }) + + // SQLite local IDs: returns 1, 2, and 3 (3 was deleted in source) + mockDataSource.rpc.executeQuery.mockImplementation( + async (opts: any) => { + if ( + opts.sql.includes('SELECT last_synced_id, last_synced_at') + ) { + return [] + } + if (opts.sql.includes('SELECT "id" FROM "users"')) { + return [{ id: 1 }, { id: 2 }, { id: 3 }] + } + return [] + } + ) + + await plugin['replicateTable']('users', 5) + + // Verify DELETE query was executed for ID 3 + expect(mockDataSource.rpc.executeQuery).toHaveBeenCalledWith( + expect.objectContaining({ + sql: expect.stringContaining( + 'DELETE FROM "users" WHERE "id" IN (?)' + ), + params: [3], + }) + ) + }) +}) diff --git a/plugins/replication/index.ts b/plugins/replication/index.ts new file mode 100644 index 0000000..47c47d0 --- /dev/null +++ b/plugins/replication/index.ts @@ -0,0 +1,485 @@ +import { StarbaseApp, StarbaseDBConfiguration } from '../../src/handler' +import { StarbasePlugin } from '../../src/plugin' +import { DataSource } from '../../src/types' +import { executeSDKQuery } from '../../src/operation' +import { CronPlugin } from '../cron' + +export class ReplicationPlugin extends StarbasePlugin { + private cronPlugin: CronPlugin + private dataSource?: DataSource + private config?: StarbaseDBConfiguration + private env: any + private isRunning: boolean = false + + constructor(opts: { cronPlugin: CronPlugin }) { + super('starbasedb:replication', { + requiresAuth: false, + }) + this.cronPlugin = opts.cronPlugin + } + + override async register(app: StarbaseApp) { + // Hono middleware to intercept and initialize context + app.use(async (c, next) => { + this.dataSource = c.get('dataSource') + this.config = c.get('config') + this.env = c.env as any + + if ( + this.env?.EXTERNAL_DB_TYPE && + this.env?.EXTERNAL_DB_TABLES_TO_TRACK + ) { + await this.initReplicationState() + await this.registerCronTask(c) + } + + await next() + }) + + // Listen for the replication cron event + this.cronPlugin.onEvent(async (event) => { + if (event.name === 'database-replication') { + await this.runReplication() + } + }) + } + + /** + * Initialize the SQLite table that keeps track of replication watermark state + */ + private async initReplicationState() { + if (!this.dataSource) return + + const createTableSQL = ` + CREATE TABLE IF NOT EXISTS tmp_replication_state ( + table_name TEXT PRIMARY KEY, + last_synced_id TEXT, + last_synced_at TEXT, + is_syncing INTEGER DEFAULT 0 + ) + ` + await this.dataSource.rpc.executeQuery({ + sql: createTableSQL, + params: [], + }) + } + + /** + * Register the database-replication task in tmp_cron_tasks if it doesn't exist + */ + private async registerCronTask(c: any) { + if (!this.dataSource) return + + const tables = this.env.EXTERNAL_DB_TABLES_TO_TRACK + if (!tables) return + + const pollingInterval = + this.env.EXTERNAL_DB_POLLING_INTERVAL || '*/1 * * * *' + const url = new URL(c.req.url) + const callbackHost = `${url.protocol}//${url.host}` + + // Check if task already exists with same config + const result = (await this.dataSource.rpc.executeQuery({ + sql: 'SELECT name, cron_tab, callback_host FROM tmp_cron_tasks WHERE name = ?', + params: ['database-replication'], + })) as any[] + + if (result && result.length > 0) { + const task = result[0] + if ( + task.cron_tab === pollingInterval && + task.callback_host === callbackHost + ) { + // Already configured correctly, do not overwrite to prevent resetting alarm execution + return + } + } + + // Add event using cron plugin + await this.cronPlugin.addEvent( + pollingInterval, + 'database-replication', + {}, + callbackHost + ) + } + + /** + * Run replication logic for all configured tables + */ + public async runReplication() { + if (this.isRunning) { + console.log('Database replication is already running. Skipping.') + return + } + + if (!this.dataSource || !this.config || !this.env) { + console.warn('ReplicationPlugin not properly initialized.') + return + } + + const tablesToTrack = this.env.EXTERNAL_DB_TABLES_TO_TRACK + if (!tablesToTrack) return + + const tables = tablesToTrack.split(',').map((t: string) => t.trim()) + const batchSize = Number(this.env.EXTERNAL_DB_BATCH_SIZE) || 500 + + this.isRunning = true + let needsFollowUp = false + + try { + for (const table of tables) { + try { + const moreData = await this.replicateTable(table, batchSize) + if (moreData) { + needsFollowUp = true + } + } catch (error) { + console.error(`Error replicating table ${table}:`, error) + } + } + + // Yield control/respect Cloudflare limits: If more data needs to be synced, + // reschedule the alarm to run again immediately (in 1 second) + if (needsFollowUp) { + console.log( + 'More data available for replication. Rescheduling alarm in 1s.' + ) + await this.dataSource.rpc.setAlarm(Date.now() + 1000) + } + } finally { + this.isRunning = false + } + } + + /** + * Replicates a single table's batch. Returns true if there is more data to sync. + */ + private async replicateTable( + tableName: string, + batchSize: number + ): Promise { + const external = this.dataSource?.external + if (!external) { + throw new Error('No external database connection configured.') + } + + const dialect = external.dialect + + // 1. Get columns and types of external table + let schemaQuery = '' + let schemaParams: any[] = [] + + if (dialect === 'postgresql') { + schemaQuery = ` + SELECT column_name, data_type, is_nullable + FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 + ` + schemaParams = [external.defaultSchema || 'public', tableName] + } else if (dialect === 'mysql') { + schemaQuery = ` + SELECT column_name as COLUMN_NAME, data_type as DATA_TYPE, is_nullable as IS_NULLABLE + FROM information_schema.columns + WHERE table_schema = ? AND table_name = ? + ` + schemaParams = [external.database, tableName] + } else { + // sqlite + schemaQuery = `PRAGMA table_info(${tableName})` + schemaParams = [] + } + + const schemaResult = await executeSDKQuery({ + sql: schemaQuery, + params: schemaParams, + dataSource: this.dataSource!, + config: this.config!, + }) + + if (!schemaResult || schemaResult.length === 0) { + throw new Error( + `Could not retrieve schema for external table ${tableName}` + ) + } + + // Normalize schema format + let columns: { + name: string + type: string + nullable: boolean + isPrimaryKey?: boolean + }[] = [] + if (dialect === 'postgresql' || dialect === 'mysql') { + columns = schemaResult.map((r: any) => ({ + name: r.column_name || r.COLUMN_NAME, + type: r.data_type || r.DATA_TYPE, + nullable: (r.is_nullable || r.IS_NULLABLE) === 'YES', + isPrimaryKey: (r.column_name || r.COLUMN_NAME) === 'id', + })) + } else { + // sqlite + columns = schemaResult.map((r: any) => ({ + name: r.name, + type: r.type, + nullable: r.notnull === 0, + isPrimaryKey: r.pk > 0, + })) + } + + // 2. Ensure local table exists + const columnDefs = columns.map((c) => { + let def = `"${c.name}" ${this.mapToSQLiteType(c.type)}` + if (c.isPrimaryKey) { + def += ' PRIMARY KEY' + } + if (!c.nullable) { + def += ' NOT NULL' + } + return def + }) + + const createTableSQL = `CREATE TABLE IF NOT EXISTS "${tableName}" (${columnDefs.join(', ')})` + await this.dataSource!.rpc.executeQuery({ sql: createTableSQL }) + + // Check for missing columns in local table (dynamic schema migration) + const localInfo = (await this.dataSource!.rpc.executeQuery({ + sql: `PRAGMA table_info("${tableName}")`, + params: [], + })) as any[] + + if (localInfo && localInfo.length > 0) { + const localColNames = new Set( + localInfo.map((r: any) => r.name.toLowerCase()) + ) + for (const col of columns) { + if (!localColNames.has(col.name.toLowerCase())) { + let alterSQL = `ALTER TABLE "${tableName}" ADD COLUMN "${col.name}" ${this.mapToSQLiteType(col.type)}` + console.log( + `Adding missing column ${col.name} to local table ${tableName}` + ) + await this.dataSource!.rpc.executeQuery({ sql: alterSQL }) + } + } + } + + // 3. Read replication state (last_synced_id / last_synced_at) + const stateResult = (await this.dataSource!.rpc.executeQuery({ + sql: 'SELECT last_synced_id, last_synced_at FROM tmp_replication_state WHERE table_name = ?', + params: [tableName], + })) as any[] + + let lastSyncedId: any = null + let lastSyncedAt: string | null = null + + if (stateResult && stateResult.length > 0) { + lastSyncedId = stateResult[0].last_synced_id + lastSyncedAt = stateResult[0].last_synced_at || null + } + + // 4. Build polling query based on columns + const pkCol = + columns.find((c) => c.isPrimaryKey)?.name || + columns.find((c) => c.name.toLowerCase() === 'id')?.name + const hasIntId = columns.some( + (c) => + c.name.toLowerCase() === 'id' && + this.mapToSQLiteType(c.type) === 'INTEGER' + ) + const hasUpdatedAt = columns.some( + (c) => c.name.toLowerCase() === 'updated_at' + ) + + let fetchSQL = '' + let fetchParams: any[] = [] + + if (hasUpdatedAt && pkCol) { + const timeVal = lastSyncedAt || '1970-01-01T00:00:00.000Z' + if (!lastSyncedAt) { + if (dialect === 'postgresql') { + fetchSQL = `SELECT * FROM "${tableName}" ORDER BY updated_at ASC, "${pkCol}" ASC LIMIT $1` + } else { + fetchSQL = `SELECT * FROM \`${tableName}\` ORDER BY updated_at ASC, \`${pkCol}\` ASC LIMIT ?` + } + fetchParams = [batchSize] + } else { + if (dialect === 'postgresql') { + fetchSQL = `SELECT * FROM "${tableName}" WHERE updated_at > $1 OR (updated_at = $2 AND "${pkCol}" > $3) ORDER BY updated_at ASC, "${pkCol}" ASC LIMIT $4` + } else { + fetchSQL = `SELECT * FROM \`${tableName}\` WHERE updated_at > ? OR (updated_at = ? AND \`${pkCol}\` > ?) ORDER BY updated_at ASC, \`${pkCol}\` ASC LIMIT ?` + } + fetchParams = [timeVal, timeVal, lastSyncedId || '', batchSize] + } + } else if (hasIntId) { + if (dialect === 'postgresql') { + fetchSQL = `SELECT * FROM "${tableName}" WHERE id > $1 ORDER BY id ASC LIMIT $2` + } else { + fetchSQL = `SELECT * FROM \`${tableName}\` WHERE id > ? ORDER BY id ASC LIMIT ?` + } + fetchParams = [Number(lastSyncedId || 0), batchSize] + } else { + // Full scan fallback + if (dialect === 'postgresql') { + fetchSQL = `SELECT * FROM "${tableName}" LIMIT $1` + } else { + fetchSQL = `SELECT * FROM \`${tableName}\` LIMIT ?` + } + fetchParams = [batchSize] + } + + const rows = await executeSDKQuery({ + sql: fetchSQL, + params: fetchParams, + dataSource: this.dataSource!, + config: this.config!, + }) + + if (!rows || rows.length === 0) { + // Sync hard deletions when table is up to date + await this.syncDeletions(tableName, columns) + return false + } + + // 5. Schema mapping of values and insert into local SQLite + const colNames = columns.map((c) => c.name) + const placeholders = colNames.map(() => '?').join(', ') + const insertSQL = `INSERT OR REPLACE INTO "${tableName}" (${colNames.map((c) => `"${c}"`).join(', ')}) VALUES (${placeholders})` + + const insertQueries = rows.map((row: any) => { + const params = colNames.map((col) => { + let val = row[col] + if (val === undefined) { + val = null + } + if (val !== null && typeof val === 'object') { + if (val instanceof Date) { + return val.toISOString() + } + return JSON.stringify(val) + } + if (typeof val === 'boolean') { + return val ? 1 : 0 + } + if (typeof val === 'bigint') { + return Number(val) + } + return val + }) + return { sql: insertSQL, params } + }) + + await this.dataSource!.rpc.executeTransaction(insertQueries, false) + + // 6. Update replication watermark state + let nextSyncedId = lastSyncedId + let nextSyncedAt = lastSyncedAt + + const lastRow = rows[rows.length - 1] + if (pkCol && lastRow[pkCol] !== undefined) { + nextSyncedId = String(lastRow[pkCol]) + } + if (hasUpdatedAt && lastRow.updated_at !== undefined) { + const val = lastRow.updated_at + nextSyncedAt = val ? new Date(val).toISOString() : null + } + + await this.dataSource!.rpc.executeQuery({ + sql: `INSERT OR REPLACE INTO tmp_replication_state (table_name, last_synced_id, last_synced_at, is_syncing) VALUES (?, ?, ?, 0)`, + params: [tableName, nextSyncedId, nextSyncedAt], + }) + + // Return true if we hit the batch limit, implying more data might be available + return rows.length === batchSize + } + + /** + * Handles hard deletions by comparing active IDs in the source to SQLite IDs + */ + private async syncDeletions(tableName: string, columns: any[]) { + const pkCol = + columns.find((c) => c.isPrimaryKey)?.name || + columns.find((c) => c.name.toLowerCase() === 'id')?.name + if (!pkCol) return // Skip if table has no primary key column + + const external = this.dataSource?.external + if (!external) return + + const dialect = external.dialect + let idQuery = '' + if (dialect === 'postgresql') { + idQuery = `SELECT "${pkCol}" FROM "${tableName}"` + } else { + idQuery = `SELECT \`${pkCol}\` FROM \`${tableName}\`` + } + + const extRows = await executeSDKQuery({ + sql: idQuery, + params: [], + dataSource: this.dataSource!, + config: this.config!, + }) + + if (!extRows) return + + const extIdsSet = new Set(extRows.map((r: any) => r[pkCol])) + + // Get local SQLite IDs + const localRows = (await this.dataSource!.rpc.executeQuery({ + sql: `SELECT "${pkCol}" FROM "${tableName}"`, + params: [], + })) as any[] + + if (!localRows || localRows.length === 0) return + + const idsToDelete = localRows + .map((r: any) => r[pkCol]) + .filter((id) => !extIdsSet.has(id)) + + if (idsToDelete.length > 0) { + console.log( + `Deleting ${idsToDelete.length} rows in local SQLite for table ${tableName} (hard deletes)` + ) + for (let i = 0; i < idsToDelete.length; i += 500) { + const chunk = idsToDelete.slice(i, i + 500) + const placeholders = chunk.map(() => '?').join(', ') + await this.dataSource!.rpc.executeQuery({ + sql: `DELETE FROM "${tableName}" WHERE "${pkCol}" IN (${placeholders})`, + params: chunk, + }) + } + } + } + + /** + * Maps database types to SQLite column types + */ + private mapToSQLiteType(externalType: string): string { + const type = externalType.toLowerCase() + if ( + type.includes('int') || + type.includes('bool') || + type.includes('boolean') + ) { + return 'INTEGER' + } + if ( + type.includes('char') || + type.includes('text') || + type.includes('uuid') || + type.includes('time') || + type.includes('date') || + type.includes('json') + ) { + return 'TEXT' + } + if ( + type.includes('double') || + type.includes('real') || + type.includes('numeric') || + type.includes('float') + ) { + return 'REAL' + } + return 'BLOB' + } +} diff --git a/src/index.ts b/src/index.ts index 4d08932..82d0aeb 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,6 +12,7 @@ import { QueryLogPlugin } from '../plugins/query-log' import { StatsPlugin } from '../plugins/stats' import { CronPlugin } from '../plugins/cron' import { InterfacePlugin } from '../plugins/interface' +import { ReplicationPlugin } from '../plugins/replication' export { StarbaseDBDurableObject } from './do' @@ -51,6 +52,10 @@ export interface Env { EXTERNAL_DB_CLOUDFLARE_ACCOUNT_ID?: string EXTERNAL_DB_CLOUDFLARE_DATABASE_ID?: string + EXTERNAL_DB_TABLES_TO_TRACK?: string + EXTERNAL_DB_POLLING_INTERVAL?: string + EXTERNAL_DB_BATCH_SIZE?: number + AUTH_ALGORITHM?: string AUTH_JWKS_ENDPOINT?: string @@ -209,6 +214,8 @@ export default { // Include cron event code here }, ctx) + const replicationPlugin = new ReplicationPlugin({ cronPlugin }) + const interfacePlugin = new InterfacePlugin() const plugins = [ @@ -224,6 +231,7 @@ export default { new QueryLogPlugin({ ctx }), cdcPlugin, cronPlugin, + replicationPlugin, new StatsPlugin(), interfacePlugin, ] satisfies StarbasePlugin[] diff --git a/src/rls/index.test.ts b/src/rls/index.test.ts index cf00156..76211ba 100644 --- a/src/rls/index.test.ts +++ b/src/rls/index.test.ts @@ -17,6 +17,10 @@ const mockConfig: StarbaseDBConfiguration = { features: { allowlist: true, rls: true, rest: true }, } +beforeEach(() => { + mockConfig.role = 'client' +}) + describe('loadPolicies - Policy Fetching and Parsing', () => { it('should load and parse policies correctly', async () => { vi.mocked(mockDataSource.rpc.executeQuery).mockResolvedValue([ @@ -74,7 +78,7 @@ describe('applyRLS - Query Modification', () => { mockDataSource.context.sub = 'user123' vi.mocked(mockDataSource.rpc.executeQuery).mockResolvedValue([ { - actions: 'SELECT', + actions: '*', schema: 'public', table: 'users', column: 'user_id', @@ -95,7 +99,7 @@ describe('applyRLS - Query Modification', () => { }) console.log('Final SQL:', modifiedSql) - expect(modifiedSql).toContain("WHERE `user_id` = 'user123'") + expect(modifiedSql).toContain("`public.users`.`user_id` = 'user123'") }) it('should modify DELETE queries by adding policy-based WHERE clause', async () => { const sql = "DELETE FROM users WHERE name = 'Alice'" @@ -106,7 +110,8 @@ describe('applyRLS - Query Modification', () => { config: mockConfig, }) - expect(modifiedSql).toContain("WHERE `name` = 'Alice'") + expect(modifiedSql).toContain("`name` = 'Alice'") + expect(modifiedSql).toContain("`public.users`.`user_id` = 'user123'") }) it('should modify UPDATE queries with additional WHERE clause', async () => { @@ -118,7 +123,9 @@ describe('applyRLS - Query Modification', () => { config: mockConfig, }) - expect(modifiedSql).toContain("`name` = 'Bob' WHERE `age` = 25") + expect(modifiedSql).toContain("`name` = 'Bob'") + expect(modifiedSql).toContain('`age` = 25') + expect(modifiedSql).toContain("`public.users`.`user_id` = 'user123'") }) it('should modify INSERT queries to enforce column values', async () => { @@ -130,7 +137,7 @@ describe('applyRLS - Query Modification', () => { config: mockConfig, }) - expect(modifiedSql).toContain("VALUES (1,'Alice')") + expect(modifiedSql).toContain("VALUES ('user123','Alice')") }) }) @@ -200,8 +207,8 @@ describe('applyRLS - Multi-Table Queries', () => { config: mockConfig, }) - expect(modifiedSql).toContain("WHERE `users.user_id` = 'user123'") - expect(modifiedSql).toContain("AND `orders.user_id` = 'user123'") + expect(modifiedSql).toContain("`public.users`.`user_id` = 'user123'") + expect(modifiedSql).toContain("`public.orders`.`user_id` = 'user123'") }) it('should apply RLS policies to multiple tables in a JOIN', async () => { @@ -218,8 +225,8 @@ describe('applyRLS - Multi-Table Queries', () => { config: mockConfig, }) - expect(modifiedSql).toContain("WHERE (users.user_id = 'user123')") - expect(modifiedSql).toContain("AND (orders.user_id = 'user123')") + expect(modifiedSql).toContain("(`public.users`.`user_id` = 'user123')") + expect(modifiedSql).toContain("(`public.orders`.`user_id` = 'user123')") }) it('should apply RLS policies to subqueries inside FROM clause', async () => { @@ -236,6 +243,6 @@ describe('applyRLS - Multi-Table Queries', () => { config: mockConfig, }) - expect(modifiedSql).toContain("WHERE `users.user_id` = 'user123'") + expect(modifiedSql).toContain("`public.users`.`user_id` = 'user123'") }) }) diff --git a/src/rls/index.ts b/src/rls/index.ts index 68abb4e..8d41ba1 100644 --- a/src/rls/index.ts +++ b/src/rls/index.ts @@ -249,36 +249,47 @@ function applyRLSToAst(ast: any): void { let tables: string[] = [] if (statementType === 'INSERT') { let tableName = normalizeIdentifier(ast.table[0].table) - if (tableName.includes('.')) { + if (tableName && tableName.includes('.')) { tableName = tableName.split('.')[1] } - tables = [tableName] + tables = tableName ? [tableName] : [] } else if (statementType === 'UPDATE') { - tables = ast.table.map((tableRef: any) => { - let tableName = normalizeIdentifier(tableRef.table) - if (tableName.includes('.')) { - tableName = tableName.split('.')[1] - } - return tableName - }) - } else { - // SELECT or DELETE - tables = - ast.from?.map((fromTable: any) => { - let tableName = normalizeIdentifier(fromTable.table) - if (tableName.includes('.')) { + tables = ast.table + .map((tableRef: any) => { + let tableName = normalizeIdentifier(tableRef.table) + if (tableName && tableName.includes('.')) { tableName = tableName.split('.')[1] } return tableName - }) || [] + }) + .filter(Boolean) + } else { + // SELECT or DELETE + tables = + ast.from + ?.map((fromTable: any) => { + let tableName = normalizeIdentifier(fromTable.table) + if (tableName && tableName.includes('.')) { + tableName = tableName.split('.')[1] + } + return tableName + }) + .filter(Boolean) || [] } const restrictedTables = Object.keys(tablesWithRules) for (const table of tables) { - if (restrictedTables.includes(table)) { - const allowedActions = tablesWithRules[table] - if (!allowedActions.includes(statementType)) { + const matchingKey = restrictedTables.find( + (rt) => + rt === table || (rt.includes('.') && rt.split('.')[1] === table) + ) + if (matchingKey) { + const allowedActions = tablesWithRules[matchingKey] + if ( + !allowedActions.includes(statementType) && + !allowedActions.includes('*') + ) { throw new Error( `Unauthorized access: No matching rules for ${statementType} on restricted table ${table}` ) @@ -292,11 +303,16 @@ function applyRLSToAst(ast: any): void { ) .forEach(({ action, condition }) => { const targetTable = normalizeIdentifier(condition.left.table) - const isTargetTable = tables.includes(targetTable) + const targetTableWithoutSchema = targetTable.includes('.') + ? targetTable.split('.')[1] + : targetTable + const isTargetTable = + tables.includes(targetTable) || + tables.includes(targetTableWithoutSchema) if (!isTargetTable) return - if (action !== 'INSERT') { + if (statementType !== 'INSERT') { // Add condition to WHERE with parentheses if (ast.where) { ast.where = { @@ -349,8 +365,15 @@ function applyRLSToAst(ast: any): void { }) ast.from?.forEach((fromItem: any) => { - if (fromItem.expr && fromItem.expr.type === 'select') { - applyRLSToAst(fromItem.expr) + if (fromItem.expr) { + if (fromItem.expr.type === 'select') { + applyRLSToAst(fromItem.expr) + } else if ( + fromItem.expr.ast && + fromItem.expr.ast.type === 'select' + ) { + applyRLSToAst(fromItem.expr.ast) + } } // Handle both single join and array of joins @@ -359,8 +382,15 @@ function applyRLSToAst(ast: any): void { ? fromItem.join : [fromItem] joins.forEach((joinItem: any) => { - if (joinItem.expr && joinItem.expr.type === 'select') { - applyRLSToAst(joinItem.expr) + if (joinItem.expr) { + if (joinItem.expr.type === 'select') { + applyRLSToAst(joinItem.expr) + } else if ( + joinItem.expr.ast && + joinItem.expr.ast.type === 'select' + ) { + applyRLSToAst(joinItem.expr.ast) + } } }) } @@ -371,8 +401,12 @@ function applyRLSToAst(ast: any): void { } ast.columns?.forEach((column: any) => { - if (column.expr && column.expr.type === 'select') { - applyRLSToAst(column.expr) + if (column.expr) { + if (column.expr.type === 'select') { + applyRLSToAst(column.expr) + } else if (column.expr.ast && column.expr.ast.type === 'select') { + applyRLSToAst(column.expr.ast) + } } }) } @@ -381,6 +415,8 @@ function traverseWhere(node: any): void { if (!node) return if (node.type === 'select') { applyRLSToAst(node) + } else if (node.ast && node.ast.type === 'select') { + applyRLSToAst(node.ast) } if (node.left) traverseWhere(node.left) if (node.right) traverseWhere(node.right) diff --git a/wrangler.toml b/wrangler.toml index 395c4ac..6481562 100644 --- a/wrangler.toml +++ b/wrangler.toml @@ -45,8 +45,8 @@ REGION = "auto" # Uncomment the section below to create a user for logging into your database UI. # You can access the Studio UI at: https://your_endpoint/studio -# STUDIO_USER = "admin" -# STUDIO_PASS = "123456" +STUDIO_USER = "admin" +STUDIO_PASS = "123456" # Toggle to enable default features ENABLE_ALLOWLIST = 0 @@ -63,6 +63,10 @@ ENABLE_RLS = 0 # EXTERNAL_DB_DATABASE = "" # EXTERNAL_DB_DEFAULT_SCHEMA = "public" +# EXTERNAL_DB_TABLES_TO_TRACK = "users,posts,comments" +# EXTERNAL_DB_POLLING_INTERVAL = "*/1 * * * *" +# EXTERNAL_DB_BATCH_SIZE = 500 + # EXTERNAL_DB_MONGODB_URI = "" # EXTERNAL_DB_TURSO_URI = "" # EXTERNAL_DB_TURSO_TOKEN = "" From 324863f94d343c0abd1c7c5a2f96a8f7aed45180 Mon Sep 17 00:00:00 2001 From: Rohan-R07 Date: Sat, 13 Jun 2026 11:55:39 +0530 Subject: [PATCH 2/2] fix: resolve typescript compilation errors in replication plugin --- plugins/replication/index.ts | 9 ++++++--- src/do.ts | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/plugins/replication/index.ts b/plugins/replication/index.ts index 47c47d0..5aff4c0 100644 --- a/plugins/replication/index.ts +++ b/plugins/replication/index.ts @@ -368,19 +368,22 @@ export class ReplicationPlugin extends StarbasePlugin { return { sql: insertSQL, params } }) - await this.dataSource!.rpc.executeTransaction(insertQueries, false) + await (this.dataSource!.rpc as any).executeTransaction( + insertQueries, + false + ) // 6. Update replication watermark state let nextSyncedId = lastSyncedId let nextSyncedAt = lastSyncedAt - const lastRow = rows[rows.length - 1] + const lastRow = rows[rows.length - 1] as any if (pkCol && lastRow[pkCol] !== undefined) { nextSyncedId = String(lastRow[pkCol]) } if (hasUpdatedAt && lastRow.updated_at !== undefined) { const val = lastRow.updated_at - nextSyncedAt = val ? new Date(val).toISOString() : null + nextSyncedAt = val ? new Date(val as any).toISOString() : null } await this.dataSource!.rpc.executeQuery({ diff --git a/src/do.ts b/src/do.ts index b6bb2b6..f8020e0 100644 --- a/src/do.ts +++ b/src/do.ts @@ -72,6 +72,7 @@ export class StarbaseDBDurableObject extends DurableObject { deleteAlarm: this.deleteAlarm.bind(this), getStatistics: this.getStatistics.bind(this), executeQuery: this.executeQuery.bind(this), + executeTransaction: this.executeTransaction.bind(this), } }