diff --git a/Cargo.lock b/Cargo.lock index f033ae8f..00ed243e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,12 +59,6 @@ dependencies = [ "equator", ] -[[package]] -name = "allocator-api2" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" - [[package]] name = "android-tzdata" version = "0.1.1" @@ -136,12 +130,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "array-init" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc" - [[package]] name = "arrayvec" version = "0.7.6" @@ -165,30 +153,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" -[[package]] -name = "aws-lc-rs" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08b5d4e069cbc868041a64bd68dc8cb39a0d79585cd6c5a24caa8c2d622121be" -dependencies = [ - "aws-lc-sys", - "untrusted 0.7.1", - "zeroize", -] - -[[package]] -name = "aws-lc-sys" -version = "0.30.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" -dependencies = [ - "bindgen", - "cc", - "cmake", - "dunce", - "fs_extra", -] - [[package]] name = "backtrace" version = "0.3.75" @@ -210,21 +174,6 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "bincode" -version = "1.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" -dependencies = [ - "serde", -] - [[package]] name = "bindgen" version = "0.69.5" @@ -237,15 +186,12 @@ dependencies = [ "itertools 0.12.1", "lazy_static", "lazycell", - "log", - "prettyplease", "proc-macro2", "quote", "regex", "rustc-hash", "shlex", "syn 2.0.104", - "which", ] [[package]] @@ -309,9 +255,6 @@ name = "bumpalo" version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" -dependencies = [ - "allocator-api2", -] [[package]] name = "bytecheck" @@ -508,15 +451,6 @@ dependencies = [ "error-code", ] -[[package]] -name = "cmake" -version = "0.1.54" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" -dependencies = [ - "cc", -] - [[package]] name = "colorchoice" version = "1.0.4" @@ -561,15 +495,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "cpufeatures" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" -dependencies = [ - "libc", -] - [[package]] name = "criterion" version = "0.5.1" @@ -712,17 +637,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "derive-new" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.104", -] - [[package]] name = "digest" version = "0.10.7" @@ -731,15 +645,8 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", - "subtle", ] -[[package]] -name = "dunce" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" - [[package]] name = "educe" version = "0.4.23" @@ -777,29 +684,6 @@ dependencies = [ "syn 2.0.104", ] -[[package]] -name = "env_filter" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" -dependencies = [ - "log", - "regex", -] - -[[package]] -name = "env_logger" -version = "0.11.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" -dependencies = [ - "anstream", - "anstyle", - "env_filter", - "jiff", - "log", -] - [[package]] name = "equator" version = "0.4.2" @@ -842,12 +726,6 @@ version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" -[[package]] -name = "fallible-iterator" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" - [[package]] name = "fastrand" version = "2.3.0" @@ -887,12 +765,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "fs_extra" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" - [[package]] name = "funty" version = "2.0.0" @@ -1074,30 +946,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - -[[package]] -name = "home" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" -dependencies = [ - "windows-sys 0.59.0", -] - [[package]] name = "humantime" version = "2.2.0" @@ -1184,17 +1032,6 @@ dependencies = [ "str_stack", ] -[[package]] -name = "io-uring" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" -dependencies = [ - "bitflags 2.9.1", - "cfg-if", - "libc", -] - [[package]] name = "is-terminal" version = "0.4.16" @@ -1236,30 +1073,6 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" -[[package]] -name = "jiff" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be1f93b8b1eb69c77f24bbb0afdf66f54b632ee39af40ca21c4365a1d7347e49" -dependencies = [ - "jiff-static", - "log", - "portable-atomic", - "portable-atomic-util", - "serde", -] - -[[package]] -name = "jiff-static" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.104", -] - [[package]] name = "jobserver" version = "0.1.33" @@ -1282,22 +1095,17 @@ dependencies = [ [[package]] name = "kite_sql" -version = "0.2.1" +version = "0.3.0" dependencies = [ "ahash 0.8.12", - "async-trait", - "base64 0.21.7", - "bincode", + "base64", "bumpalo", "byteorder", "chrono", - "clap", "comfy-table", "criterion", "csv", - "env_logger", "fixedbitset", - "futures", "getrandom 0.2.16", "getrandom 0.3.3", "indicatif", @@ -1307,26 +1115,19 @@ dependencies = [ "librocksdb-sys", "lmdb", "lmdb-sys", - "log", "once_cell", "ordered-float", - "parking_lot", "paste", - "pgwire", "pprof", "pyo3", "recursive", "rocksdb", "rust_decimal", "rustyline", - "serde", - "serde-wasm-bindgen", "siphasher", "sqlite", "sqlparser", "tempfile", - "thiserror 1.0.69", - "tokio", "ulid", "wasm-bindgen", "wasm-bindgen-test", @@ -1343,29 +1144,6 @@ dependencies = [ "syn 2.0.104", ] -[[package]] -name = "lazy-regex" -version = "3.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60c7310b93682b36b98fa7ea4de998d3463ccbebd94d935d6b48ba5b6ffa7126" -dependencies = [ - "lazy-regex-proc_macros", - "once_cell", - "regex-lite", -] - -[[package]] -name = "lazy-regex-proc_macros" -version = "3.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ba01db5ef81e17eb10a5e0f2109d1b3a3e29bac3070fdbd7d156bf7dbd206a1" -dependencies = [ - "proc-macro2", - "quote", - "regex", - "syn 2.0.104", -] - [[package]] name = "lazy_static" version = "1.5.0" @@ -1435,12 +1213,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "linux-raw-sys" -version = "0.4.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" - [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -1492,7 +1264,6 @@ dependencies = [ "kite_sql", "lazy_static", "rust_decimal", - "serde", "sqlparser", "tempfile", ] @@ -1507,12 +1278,6 @@ dependencies = [ "digest", ] -[[package]] -name = "md5" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" - [[package]] name = "memchr" version = "2.7.5" @@ -1562,17 +1327,6 @@ dependencies = [ "adler2", ] -[[package]] -name = "mio" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" -dependencies = [ - "libc", - "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys 0.59.0", -] - [[package]] name = "nix" version = "0.26.4" @@ -1704,8 +1458,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" dependencies = [ "num-traits", - "rand 0.8.5", - "serde", ] [[package]] @@ -1714,59 +1466,12 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" -[[package]] -name = "parking_lot" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets 0.52.6", -] - [[package]] name = "paste" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" -[[package]] -name = "pgwire" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c84e671791f3a354f265e55e400be8bb4b6262c1ec04fac4289e710ccf22ab43" -dependencies = [ - "async-trait", - "aws-lc-rs", - "bytes", - "chrono", - "derive-new", - "futures", - "hex", - "lazy-regex", - "md5", - "postgres-types", - "rand 0.8.5", - "rust_decimal", - "thiserror 2.0.12", - "tokio", - "tokio-rustls", - "tokio-util", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -1819,46 +1524,6 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" -[[package]] -name = "portable-atomic-util" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" -dependencies = [ - "portable-atomic", -] - -[[package]] -name = "postgres-protocol" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76ff0abab4a9b844b93ef7b81f1efc0a366062aaef2cd702c76256b5dc075c54" -dependencies = [ - "base64 0.22.1", - "byteorder", - "bytes", - "fallible-iterator", - "hmac", - "md-5", - "memchr", - "rand 0.9.1", - "sha2", - "stringprep", -] - -[[package]] -name = "postgres-types" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" -dependencies = [ - "array-init", - "bytes", - "chrono", - "fallible-iterator", - "postgres-protocol", -] - [[package]] name = "pprof" version = "0.15.0" @@ -1891,16 +1556,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "prettyplease" -version = "0.2.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" -dependencies = [ - "proc-macro2", - "syn 2.0.104", -] - [[package]] name = "proc-macro-crate" version = "3.3.0" @@ -2050,7 +1705,6 @@ dependencies = [ "libc", "rand_chacha 0.3.1", "rand_core 0.6.4", - "serde", ] [[package]] @@ -2090,7 +1744,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ "getrandom 0.2.16", - "serde", ] [[package]] @@ -2142,15 +1795,6 @@ dependencies = [ "syn 2.0.104", ] -[[package]] -name = "redox_syscall" -version = "0.5.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" -dependencies = [ - "bitflags 2.9.1", -] - [[package]] name = "regex" version = "1.11.1" @@ -2174,12 +1818,6 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "regex-lite" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -2204,20 +1842,6 @@ dependencies = [ "bytemuck", ] -[[package]] -name = "ring" -version = "0.17.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" -dependencies = [ - "cc", - "cfg-if", - "getrandom 0.2.16", - "libc", - "untrusted 0.9.0", - "windows-sys 0.52.0", -] - [[package]] name = "rkyv" version = "0.7.45" @@ -2267,7 +1891,6 @@ dependencies = [ "borsh", "bytes", "num-traits", - "postgres-types", "rand 0.8.5", "rkyv", "serde", @@ -2286,19 +1909,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" -[[package]] -name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags 2.9.1", - "errno", - "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - [[package]] name = "rustix" version = "1.0.8" @@ -2308,46 +1918,10 @@ dependencies = [ "bitflags 2.9.1", "errno", "libc", - "linux-raw-sys 0.9.4", + "linux-raw-sys", "windows-sys 0.60.2", ] -[[package]] -name = "rustls" -version = "0.23.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" -dependencies = [ - "aws-lc-rs", - "log", - "once_cell", - "rustls-pki-types", - "rustls-webpki", - "subtle", - "zeroize", -] - -[[package]] -name = "rustls-pki-types" -version = "1.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" -dependencies = [ - "zeroize", -] - -[[package]] -name = "rustls-webpki" -version = "0.103.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" -dependencies = [ - "aws-lc-rs", - "ring", - "rustls-pki-types", - "untrusted 0.9.0", -] - [[package]] name = "rustversion" version = "1.0.21" @@ -2409,17 +1983,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde-wasm-bindgen" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b" -dependencies = [ - "js-sys", - "serde", - "wasm-bindgen", -] - [[package]] name = "serde_derive" version = "1.0.219" @@ -2443,32 +2006,12 @@ dependencies = [ "serde", ] -[[package]] -name = "sha2" -version = "0.10.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "shlex" version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signal-hook-registry" -version = "1.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" -dependencies = [ - "libc", -] - [[package]] name = "simdutf8" version = "0.1.5" @@ -2486,9 +2029,6 @@ name = "siphasher" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" -dependencies = [ - "serde", -] [[package]] name = "slab" @@ -2502,16 +2042,6 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "spin" version = "0.10.0" @@ -2619,29 +2149,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" -[[package]] -name = "stringprep" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" -dependencies = [ - "unicode-bidi", - "unicode-normalization", - "unicode-properties", -] - [[package]] name = "strsim" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "subtle" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" - [[package]] name = "symbolic-common" version = "12.15.5" @@ -2708,7 +2221,7 @@ dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", - "rustix 1.0.8", + "rustix", "windows-sys 0.59.0", ] @@ -2795,60 +2308,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "tokio" -version = "1.46.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" -dependencies = [ - "backtrace", - "bytes", - "io-uring", - "libc", - "mio", - "parking_lot", - "pin-project-lite", - "signal-hook-registry", - "slab", - "socket2", - "tokio-macros", - "windows-sys 0.52.0", -] - -[[package]] -name = "tokio-macros" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.104", -] - -[[package]] -name = "tokio-rustls" -version = "0.26.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" -dependencies = [ - "rustls", - "tokio", -] - -[[package]] -name = "tokio-util" -version = "0.7.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "pin-project-lite", - "tokio", -] - [[package]] name = "toml_datetime" version = "0.6.11" @@ -2879,7 +2338,6 @@ dependencies = [ "rand 0.8.5", "rust_decimal", "sqlite", - "thiserror 1.0.69", ] [[package]] @@ -2926,37 +2384,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "470dbf6591da1b39d43c14523b2b469c86879a53e8b758c8e090a470fe7b1fbe" dependencies = [ "rand 0.9.1", - "serde", "web-time", ] -[[package]] -name = "unicode-bidi" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" - [[package]] name = "unicode-ident" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" -[[package]] -name = "unicode-normalization" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" -dependencies = [ - "tinyvec", -] - -[[package]] -name = "unicode-properties" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" - [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -2981,18 +2417,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" -[[package]] -name = "untrusted" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" - -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - [[package]] name = "utf8parse" version = "0.2.2" @@ -3156,18 +2580,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix 0.38.44", -] - [[package]] name = "winapi" version = "0.3.9" @@ -3459,9 +2871,3 @@ dependencies = [ "quote", "syn 2.0.104", ] - -[[package]] -name = "zeroize" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/Cargo.toml b/Cargo.toml index 2697c842..138c68ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "kite_sql" -version = "0.2.1" +version = "0.3.0" edition = "2021" build = "build.rs" authors = ["Kould ", "Xwg "] @@ -13,32 +13,30 @@ repository = "https://github.com/KipData/KiteSQL" readme = "README.md" keywords = ["sql", "sqlite", "database", "mysql"] categories = ["development-tools", "database"] -default-run = "kite_sql" - -[[bin]] -name = "kite_sql" -path = "src/bin/server.rs" -required-features = ["net", "rocksdb"] [[bin]] name = "kitesql-shell" path = "src/bin/shell.rs" -required-features = ["rocksdb"] +required-features = ["rocksdb", "shell"] [lib] doctest = false crate-type = ["cdylib", "rlib"] [features] -default = ["macros", "rocksdb"] +default = ["time", "macros", "parser", "rocksdb"] +time = ["dep:chrono"] +copy = ["dep:csv"] +decimal = ["dep:rust_decimal"] macros = [] -orm = [] +orm = ["macros"] +parser = ["dep:sqlparser"] rocksdb = ["dep:rocksdb"] unsafe_txdb_checkpoint = ["rocksdb", "dep:librocksdb-sys"] lmdb = ["dep:lmdb", "dep:lmdb-sys"] -net = ["rocksdb", "dep:pgwire", "dep:async-trait", "dep:clap", "dep:env_logger", "dep:futures", "dep:log", "dep:tokio"] pprof = ["pprof/criterion", "pprof/flamegraph"] -python = ["dep:pyo3"] +python = ["parser", "dep:pyo3"] +shell = ["parser", "dep:comfy-table", "dep:rustyline"] [[bench]] name = "query_bench" @@ -48,36 +46,24 @@ required-features = ["pprof"] [dependencies] ahash = { version = "0.8" } -bincode = { version = "1" } -bumpalo = { version = "3", features = ["allocator-api2", "collections", "std"] } +bumpalo = { version = "3", default-features = false, features = ["collections"] } byteorder = { version = "1" } -chrono = { version = "0.4" } -comfy-table = { version = "7", default-features = false } -csv = { version = "1" } fixedbitset = { version = "0.4" } itertools = { version = "0.12" } -ordered-float = { version = "4", features = ["serde"] } +ordered-float = { version = "4" } paste = { version = "1" } -parking_lot = { version = "0.12", features = ["arc_lock"] } -pyo3 = { version = "0.23", features = ["auto-initialize"], optional = true } recursive = { version = "0.1" } -rust_decimal = { version = "1" } -serde = { version = "1", features = ["derive", "rc"] } kite_sql_serde_macros = { version = "0.2.0", path = "kite_sql_serde_macros" } -siphasher = { version = "1", features = ["serde"] } -sqlparser = { version = "0.61", default-features = false, features = ["std"] } -thiserror = { version = "1" } -ulid = { version = "1", features = ["serde"] } - -# Feature: net -async-trait = { version = "0.1", optional = true } -clap = { version = "4.5", features = ["derive"], optional = true } -env_logger = { version = "0.11", optional = true } -futures = { version = "0.3", optional = true } -log = { version = "0.4", optional = true } -pgwire = { version = "0.28.0", optional = true } -tokio = { version = "1.36", features = ["full"], optional = true } +siphasher = { version = "1" } +ulid = { version = "1" } +# Optional dependencies for features +comfy-table = { version = "7", default-features = false, optional = true } +chrono = { version = "0.4", optional = true } +csv = { version = "1", optional = true } +pyo3 = { version = "0.23", features = ["auto-initialize"], optional = true } +rust_decimal = { version = "1", default-features = false, features = ["std"], optional = true } +sqlparser = { version = "0.61", default-features = false, features = ["std"], optional = true } [target.'cfg(unix)'.dev-dependencies] pprof = { version = "0.15", features = ["flamegraph", "criterion"] } @@ -90,11 +76,12 @@ tempfile = { version = "3.10" } sqlite = { version = "0.34" } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -rocksdb = { version = "0.23", optional = true, default-features = false, features = ["bindgen-runtime"] } +# Optional dependencies for features librocksdb-sys = { version = "0.17.1", optional = true } lmdb = { version = "0.8.0", optional = true } lmdb-sys = { version = "0.8.0", optional = true } -rustyline = { version = "14", default-features = false } +rocksdb = { version = "0.23", optional = true, default-features = false, features = ["bindgen-runtime"] } +rustyline = { version = "14", default-features = false, optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = { version = "0.2.106" } @@ -106,7 +93,6 @@ base64 = { version = "0.21" } getrandom = { version = "0.2", features = ["js"] } getrandom_03 = { package = "getrandom", version = "0.3", features = ["wasm_js"] } js-sys = { version = "0.3.83" } -serde-wasm-bindgen = { version = "0.6.5" } once_cell = { version = "1" } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] diff --git a/Makefile b/Makefile index 9616bfb2..5b0de543 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ test: ## Run Python binding API tests implemented with pyo3. test-python: - PYO3_PYTHON=$(PYO3_PYTHON) $(CARGO) test --features python test_python_ + PYO3_PYTHON=$(PYO3_PYTHON) $(CARGO) test --features python,decimal test_python_ ## Perform a `cargo check` across the workspace. cargo-check: diff --git a/README.md b/README.md index 3def9381..91567183 100755 --- a/README.md +++ b/README.md @@ -55,7 +55,8 @@ For the full ORM guide, see [`src/orm/README.md`](src/orm/README.md). ```rust use kite_sql::db::DataBaseBuilder; use kite_sql::errors::DatabaseError; -use kite_sql::{Model, Projection}; +use kite_sql::orm::OrmQueryResultExt; +use kite_sql::Model; #[derive(Default, Debug, PartialEq, Model)] #[model(table = "users")] @@ -71,15 +72,8 @@ struct User { age: Option, } -#[derive(Default, Debug, PartialEq, Projection)] -struct UserSummary { - id: i32, - #[projection(rename = "user_name")] - display_name: String, -} - fn main() -> Result<(), DatabaseError> { - let database = DataBaseBuilder::path("./data").build_rocksdb()?; + let mut database = DataBaseBuilder::path("./data").build_rocksdb()?; // Or: let database = DataBaseBuilder::path("./data").build_lmdb()?; database.migrate::()?; @@ -100,24 +94,31 @@ fn main() -> Result<(), DatabaseError> { ])?; database - .from::() - .eq(User::id(), 1) - .update() - .set(User::age(), Some(19)) - .execute()?; + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(1))? + .update(|u| u.set_value(User::age(), Some(19))) + })? + .done()?; database - .from::() - .eq(User::id(), 2) - .delete()?; + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(2))? + .delete() + })? + .done()?; let users = database - .from::() - .gte(User::age(), 18) - .project::() - .asc(User::name()) - .limit(10) - .fetch()?; + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::age())?.gte(18))? + .project_scalars((User::id(), User::name()))? + .order_by(User::name())? + .limit(10)? + .finish() + })? + .project_tuple::<(i32, String)>(); for user in users { println!("{:?}", user?); @@ -138,6 +139,7 @@ fn main() -> Result<(), DatabaseError> { - Transaction isolation is documented in [`docs/transaction-isolation.md`](docs/transaction-isolation.md). - Cargo features: - `rocksdb` is enabled by default + - `parser` is enabled by default and provides the SQL parser frontend - `lmdb` is optional - `unsafe_txdb_checkpoint` enables experimental checkpoint support for RocksDB `TransactionDB` - `cargo check --no-default-features --features lmdb` builds an LMDB-only native configuration @@ -164,7 +166,7 @@ Checkpoint support and feature-gating details are documented in [docs/features.m import { WasmDatabase } from "./pkg/kite_sql.js"; const db = new WasmDatabase(); -await db.execute("create table demo(id int primary key, v int)"); +await db.ddl("create table demo(id int primary key, v int)"); await db.execute("insert into demo values (1, 2), (2, 4)"); const rows = db.run("select * from demo").rows(); console.log(rows.map((r) => r.values.map((v) => v.Int32 ?? v))); @@ -198,10 +200,10 @@ Recent 720-second local comparison on the machine above: | Backend | TpmC | New-Order p90 | Payment p90 | Order-Status p90 | Delivery p90 | Stock-Level p90 | | --- | ---: | ---: | ---: | ---: | ---: | ---: | -| KiteSQL LMDB | 68394 | 0.001s | 0.001s | 0.001s | 0.002s | 0.001s | -| KiteSQL RocksDB | 30387 | 0.001s | 0.001s | 0.001s | 0.015s | 0.002s | -| SQLite balanced | 41690 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | -| SQLite practical | 38861 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | +| KiteSQL LMDB | 61723 | 0.001s | 0.001s | 0.001s | 0.002s | 0.001s | +| KiteSQL RocksDB | 30446 | 0.001s | 0.001s | 0.001s | 0.016s | 0.002s | +| SQLite balanced | 42989 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | +| SQLite practical | 42276 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | The detailed raw outputs are recorded in [tpcc/README.md](tpcc/README.md). #### 👉[check more](tpcc/README.md) diff --git a/examples/hello_world.rs b/examples/hello_world.rs index 7b7a5bd0..7b2dd302 100644 --- a/examples/hello_world.rs +++ b/examples/hello_world.rs @@ -47,7 +47,7 @@ mod app { pub c2: String, } - fn run_with_database(database: Database) -> Result<(), DatabaseError> { + fn run_with_database(mut database: Database) -> Result<(), DatabaseError> { database.create_table_if_not_exists::()?; database.insert(&MyStruct { c1: 0, @@ -63,22 +63,25 @@ mod app { })?; database - .from::() - .eq(MyStruct::c1(), 1) - .update() - .set(MyStruct::c2(), "ONE") - .execute()?; - database.from::().eq(MyStruct::c1(), 2).delete()?; + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(MyStruct::c1())?.eq(1))? + .update(|u| u.set_value(MyStruct::c2(), "ONE")) + })? + .done()?; + database + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(MyStruct::c1())?.eq(2))? + .delete() + })? + .done()?; for row in database.fetch::()? { println!("{:?}", row?); } - let mut agg = database.run("select count(*) from my_struct")?; - if let Some(count_row) = agg.next() { - println!("row count = {:?}", count_row?); - } - agg.done()?; + println!("row count = {}", database.fetch::()?.count()); database.drop_table::()?; @@ -87,23 +90,39 @@ mod app { pub fn run() -> Result<(), DatabaseError> { reset_example_dir()?; - let backend = env::var("KITESQL_BACKEND").unwrap_or_else(|_| "rocksdb".to_string()); + let backend = env::var("KITESQL_BACKEND").unwrap_or_else(|_| { + #[cfg(feature = "rocksdb")] + { + "rocksdb".to_string() + } + #[cfg(all(not(feature = "rocksdb"), feature = "lmdb"))] + { + "lmdb".to_string() + } + #[cfg(all(not(feature = "rocksdb"), not(feature = "lmdb")))] + { + "memory".to_string() + } + }); match backend.to_ascii_lowercase().as_str() { + "memory" => { + run_with_database(DataBaseBuilder::path(EXAMPLE_DB_PATH).build_in_memory()?) + } #[cfg(feature = "rocksdb")] "rocksdb" => run_with_database(DataBaseBuilder::path(EXAMPLE_DB_PATH).build_rocksdb()?), #[cfg(feature = "lmdb")] "lmdb" => run_with_database(DataBaseBuilder::path(EXAMPLE_DB_PATH).build_lmdb()?), other => Err(DatabaseError::InvalidValue(format!( "unsupported example backend '{other}', expected {}", - { - let mut expected = Vec::new(); + [ + "memory", #[cfg(feature = "rocksdb")] - expected.push("rocksdb"); + "rocksdb", #[cfg(feature = "lmdb")] - expected.push("lmdb"); - expected.join(" or ") - } + "lmdb", + ] + .join(" or ") ))), } } diff --git a/examples/transaction.rs b/examples/transaction.rs index 6e0db24e..9e145307 100644 --- a/examples/transaction.rs +++ b/examples/transaction.rs @@ -41,10 +41,8 @@ mod app { pub fn run() -> Result<(), DatabaseError> { reset_example_dir()?; // Optimistic transactions are currently backed by RocksDB. - let database = DataBaseBuilder::path(EXAMPLE_DB_PATH).build_optimistic()?; - database - .run("create table if not exists t1 (c1 int primary key, c2 int)")? - .done()?; + let mut database = DataBaseBuilder::path(EXAMPLE_DB_PATH).build_optimistic()?; + database.ddl("create table if not exists t1 (c1 int primary key, c2 int)")?; let mut transaction = database.new_transaction()?; transaction @@ -65,6 +63,7 @@ mod app { Tuple::new(None, vec![DataValue::Int32(1), DataValue::Int32(1)]) ); assert!(iter.next().is_none()); + iter.done()?; let mut tx2 = database.new_transaction()?; tx2.run("update t1 set c2 = 99 where c1 = 0")?.done()?; @@ -79,7 +78,7 @@ mod app { ); drop(tx2); - database.run("drop table t1")?.done()?; + database.ddl("drop table t1")?; Ok(()) } diff --git a/examples/wasm_hello_world.test.mjs b/examples/wasm_hello_world.test.mjs index 64a2db61..e8e347f1 100644 --- a/examples/wasm_hello_world.test.mjs +++ b/examples/wasm_hello_world.test.mjs @@ -8,8 +8,8 @@ const { WasmDatabase } = require("../pkg/kite_sql.js"); async function main() { const db = new WasmDatabase(); - await db.execute("drop table if exists my_struct"); - await db.execute("create table my_struct (c1 int primary key, c2 int)"); + await db.ddl("drop table if exists my_struct"); + await db.ddl("create table my_struct (c1 int primary key, c2 int)"); await db.execute("insert into my_struct values(0, 0), (1, 1)"); const iter = db.run("select * from my_struct"); @@ -48,7 +48,7 @@ async function main() { [1, 11], ]); - await db.execute("drop table my_struct"); + await db.ddl("drop table my_struct"); console.log("wasm hello_world test passed"); } diff --git a/examples/wasm_index_usage.test.mjs b/examples/wasm_index_usage.test.mjs index 8ba437f8..65587dd5 100644 --- a/examples/wasm_index_usage.test.mjs +++ b/examples/wasm_index_usage.test.mjs @@ -38,8 +38,8 @@ const { WasmDatabase } = require("../pkg/kite_sql.js"); async function main() { const db = new WasmDatabase(); - await db.execute("drop table if exists t1"); - await db.execute("create table t1(id int primary key, c1 int, c2 int)"); + await db.ddl("drop table if exists t1"); + await db.ddl("create table t1(id int primary key, c1 int, c2 int)"); // Insert data in bulk (20k rows) without reading from disk. // Each row matches the old CSV pattern: id = i*3, c1 = i*3+1, c2 = i*3+2. @@ -51,10 +51,10 @@ async function main() { } // Add indexes and analyze - await db.execute("create unique index u_c1_index on t1 (c1)"); - await db.execute("create index c2_index on t1 (c2)"); - await db.execute("create index p_index on t1 (c1, c2)"); - await db.execute("analyze table t1"); + await db.ddl("create unique index u_c1_index on t1 (c1)"); + await db.ddl("create index c2_index on t1 (c2)"); + await db.ddl("create index p_index on t1 (c1, c2)"); + await db.analyze("t1"); const rowVals = (row) => { const ints = row.values.map((v) => v.Int32 ?? v); @@ -101,7 +101,7 @@ async function main() { const afterDelete = db.run("select * from t1 where c2 = 123456").rows().map(rowVals); assert.equal(afterDelete.length, 0); - await db.execute("drop table t1"); + await db.ddl("drop table t1"); console.log("wasm index usage test passed"); } diff --git a/kite_sql_serde_macros/src/orm.rs b/kite_sql_serde_macros/src/orm.rs index 40b04056..a4fe59fc 100644 --- a/kite_sql_serde_macros/src/orm.rs +++ b/kite_sql_serde_macros/src/orm.rs @@ -12,12 +12,12 @@ struct OrmOpts { generics: Generics, table: Option, #[darling(default, multiple, rename = "index")] - indexes: Vec, + indexes: Vec, data: Data<(), OrmFieldOpts>, } #[derive(Debug, FromMeta)] -struct OrmIndexOpts { +struct ModelIndexOpts { name: String, columns: String, #[darling(default)] @@ -36,7 +36,7 @@ struct OrmFieldOpts { decimal_precision: Option, decimal_scale: Option, #[darling(rename = "default")] - default_expr: Option, + default_literal: Option, #[darling(default)] skip: bool, #[darling(default)] @@ -73,8 +73,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { let mut field_getters = Vec::new(); let mut column_names = Vec::new(); let mut placeholder_names = Vec::new(); - let mut create_index_statements = Vec::new(); - let mut create_index_if_not_exists_statements = Vec::new(); + let mut orm_indexes = Vec::new(); let mut persisted_columns = Vec::new(); let mut index_names = BTreeSet::new(); index_names.insert("pk_index".to_string()); @@ -121,7 +120,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { "char field cannot be skipped", )); } - if field.default_expr.is_some() { + if field.default_literal.is_some() { return Err(Error::new_spanned( field_name, "default field cannot be skipped", @@ -152,8 +151,8 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { "decimal_scale requires decimal_precision", )); } - let default_expr = field - .default_expr + let default_literal = field + .default_literal .map(|value| LitStr::new(&value, Span::call_site())); let field_name_string = field_name.to_string(); let column_name = field.rename.unwrap_or_else(|| field_name_string.clone()); @@ -163,6 +162,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { let is_primary_key = field.primary_key; let is_unique = field.unique; let is_index = field.index; + let column_index = orm_columns.len(); persisted_columns.push((field_name_string, column_name.clone())); column_names.push(column_name.clone()); @@ -204,23 +204,23 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { .push(parse_quote!(#field_ty : ::kite_sql::orm::DecimalType)); } - let ddl_type = if let Some(varchar_len) = varchar_len { - quote! { ::std::format!("varchar({})", #varchar_len) } + let data_type = if let Some(varchar_len) = varchar_len { + quote! { ::kite_sql::types::LogicalType::Varchar(Some(#varchar_len), ::kite_sql::types::CharLengthUnits::Characters) } } else if let Some(char_len) = char_len { - quote! { ::std::format!("char({})", #char_len) } + quote! { ::kite_sql::types::LogicalType::Char(#char_len, ::kite_sql::types::CharLengthUnits::Characters) } } else if let Some(decimal_precision) = decimal_precision { if let Some(decimal_scale) = decimal_scale { - quote! { ::std::format!("decimal({}, {})", #decimal_precision, #decimal_scale) } + quote! { ::kite_sql::types::LogicalType::Decimal(Some(#decimal_precision), Some(#decimal_scale)) } } else { - quote! { ::std::format!("decimal({})", #decimal_precision) } + quote! { ::kite_sql::types::LogicalType::Decimal(Some(#decimal_precision), None) } } } else { - quote! { <#field_ty as ::kite_sql::orm::ModelColumnType>::ddl_type() } + quote! { <#field_ty as ::kite_sql::orm::ModelColumnType>::logical_type() } }; - let default_expr_tokens = if let Some(default_expr) = &default_expr { - quote! { Some(#default_expr) } + let default_tokens = if let Some(default_literal) = &default_literal { + quote! { Some(::kite_sql::types::value::DataValue::from(#default_literal.to_string())) } } else { - quote! { None } + quote! { None::<::kite_sql::types::value::DataValue> } }; assignments.push(quote! { @@ -234,6 +234,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { orm_fields.push(quote! { ::kite_sql::orm::OrmField { column: #column_name_lit, + column_index: #column_index, placeholder: #placeholder_lit, primary_key: #is_primary_key, unique: #is_unique, @@ -246,13 +247,32 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { } }); orm_columns.push(quote! { - ::kite_sql::orm::OrmColumn { - name: #column_name_lit, - ddl_type: #ddl_type, - nullable: <#field_ty as ::kite_sql::orm::ModelColumnType>::nullable(), - primary_key: #is_primary_key, - unique: #is_unique, - default_expr: #default_expr_tokens, + { + let data_type = #data_type; + let default = #default_tokens + .map(|value| { + ::kite_sql::expression::ScalarExpression::Constant( + value + .cast(&data_type) + .expect("failed to cast ORM default value to column type"), + ) + }); + let desc = ::kite_sql::catalog::column::ColumnDesc::new( + data_type, + #is_primary_key.then_some(#column_index), + #is_unique, + default, + ) + .expect("failed to build ORM column descriptor"); + ::kite_sql::catalog::column::ColumnCatalog::new( + #column_name_lit.to_string(), + if #is_primary_key { + false + } else { + <#field_ty as ::kite_sql::orm::ModelColumnType>::nullable() + }, + desc, + ) } }); if is_unique { @@ -274,23 +294,8 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { format!("duplicate ORM index name: {index_name}"), )); } - create_index_statements.push(quote! { - ::kite_sql::orm::orm_create_index_statement( - #table_name_lit, - #index_name_lit, - &[#column_name_for_index], - false, - false, - ) - }); - create_index_if_not_exists_statements.push(quote! { - ::kite_sql::orm::orm_create_index_statement( - #table_name_lit, - #index_name_lit, - &[#column_name_for_index], - false, - true, - ) + orm_indexes.push(quote! { + (#index_name_lit, &[#column_name_for_index], false) }); } } @@ -358,23 +363,8 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { .map(|column| LitStr::new(column, Span::call_site())) .collect::>(); let is_unique = index.unique; - create_index_statements.push(quote! { - ::kite_sql::orm::orm_create_index_statement( - #table_name_lit, - #index_name_lit, - &[#(#index_columns),*], - #is_unique, - false, - ) - }); - create_index_if_not_exists_statements.push(quote! { - ::kite_sql::orm::orm_create_index_statement( - #table_name_lit, - #index_name_lit, - &[#(#index_columns),*], - #is_unique, - true, - ) + orm_indexes.push(quote! { + (#index_name_lit, &[#(#index_columns),*], #is_unique) }); } @@ -389,14 +379,18 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { let primary_key_value = primary_key_value.expect("primary key checked above"); let _primary_key_column = primary_key_column.expect("primary key checked above"); let _primary_key_placeholder = primary_key_placeholder.expect("primary key checked above"); + let mut from_generics = generics.clone(); + from_generics.params.insert(0, parse_quote!('__kite_arena)); + from_generics.params.insert(0, parse_quote!('__kite_schema)); + let (from_impl_generics, _, from_where_clause) = from_generics.split_for_impl(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); Ok(quote! { - impl #impl_generics ::core::convert::From<(&::kite_sql::types::tuple::SchemaRef, ::kite_sql::types::tuple::Tuple)> + impl #from_impl_generics ::core::convert::From<(&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)> for #struct_name #ty_generics - #where_clause + #from_where_clause { - fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaRef, ::kite_sql::types::tuple::Tuple)) -> Self { + fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)) -> Self { let mut struct_instance = ::default(); #(#assignments)* @@ -426,8 +420,8 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { ] } - fn columns() -> &'static [::kite_sql::orm::OrmColumn] { - static ORM_COLUMNS: ::std::sync::LazyLock<::std::vec::Vec<::kite_sql::orm::OrmColumn>> = ::std::sync::LazyLock::new(|| { + fn columns() -> &'static [::kite_sql::catalog::column::ColumnCatalog] { + static ORM_COLUMNS: ::std::sync::LazyLock<::std::vec::Vec<::kite_sql::catalog::column::ColumnCatalog>> = ::std::sync::LazyLock::new(|| { vec![ #(#orm_columns),* ] @@ -435,6 +429,12 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { ORM_COLUMNS.as_slice() } + fn indexes() -> &'static [(&'static str, &'static [&'static str], bool)] { + &[ + #(#orm_indexes),* + ] + } + fn params(&self) -> Vec<(&'static str, ::kite_sql::types::value::DataValue)> { vec![ #(#params),* @@ -444,100 +444,6 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { fn primary_key(&self) -> &Self::PrimaryKey { #primary_key_value } - - fn select_statement() -> &'static ::kite_sql::db::Statement { - static SELECT_STATEMENT: ::std::sync::LazyLock<::kite_sql::db::Statement> = ::std::sync::LazyLock::new(|| { - ::kite_sql::orm::orm_select_statement( - #table_name_lit, - <#struct_name #ty_generics as ::kite_sql::orm::Model>::fields(), - ) - }); - &SELECT_STATEMENT - } - - fn insert_statement() -> &'static ::kite_sql::db::Statement { - static INSERT_STATEMENT: ::std::sync::LazyLock<::kite_sql::db::Statement> = ::std::sync::LazyLock::new(|| { - ::kite_sql::orm::orm_insert_statement( - #table_name_lit, - <#struct_name #ty_generics as ::kite_sql::orm::Model>::fields(), - ) - }); - &INSERT_STATEMENT - } - - fn find_statement() -> &'static ::kite_sql::db::Statement { - static FIND_STATEMENT: ::std::sync::LazyLock<::kite_sql::db::Statement> = ::std::sync::LazyLock::new(|| { - ::kite_sql::orm::orm_find_statement( - #table_name_lit, - <#struct_name #ty_generics as ::kite_sql::orm::Model>::fields(), - <#struct_name #ty_generics as ::kite_sql::orm::Model>::primary_key_field(), - ) - }); - &FIND_STATEMENT - } - - fn create_table_statement() -> &'static ::kite_sql::db::Statement { - static CREATE_TABLE_STATEMENT: ::std::sync::LazyLock<::kite_sql::db::Statement> = ::std::sync::LazyLock::new(|| { - ::kite_sql::orm::orm_create_table_statement( - #table_name_lit, - <#struct_name #ty_generics as ::kite_sql::orm::Model>::columns(), - false, - ) - .expect("failed to build ORM create table statement") - }); - &CREATE_TABLE_STATEMENT - } - - fn create_table_if_not_exists_statement() -> &'static ::kite_sql::db::Statement { - static CREATE_TABLE_IF_NOT_EXISTS_STATEMENT: ::std::sync::LazyLock<::kite_sql::db::Statement> = ::std::sync::LazyLock::new(|| { - ::kite_sql::orm::orm_create_table_statement( - #table_name_lit, - <#struct_name #ty_generics as ::kite_sql::orm::Model>::columns(), - true, - ) - .expect("failed to build ORM create table if not exists statement") - }); - &CREATE_TABLE_IF_NOT_EXISTS_STATEMENT - } - - fn create_index_statements() -> &'static [::kite_sql::db::Statement] { - static CREATE_INDEX_STATEMENTS: ::std::sync::LazyLock<::std::vec::Vec<::kite_sql::db::Statement>> = ::std::sync::LazyLock::new(|| { - vec![ - #(#create_index_statements),* - ] - }); - CREATE_INDEX_STATEMENTS.as_slice() - } - - fn create_index_if_not_exists_statements() -> &'static [::kite_sql::db::Statement] { - static CREATE_INDEX_IF_NOT_EXISTS_STATEMENTS: ::std::sync::LazyLock<::std::vec::Vec<::kite_sql::db::Statement>> = ::std::sync::LazyLock::new(|| { - vec![ - #(#create_index_if_not_exists_statements),* - ] - }); - CREATE_INDEX_IF_NOT_EXISTS_STATEMENTS.as_slice() - } - - fn drop_table_statement() -> &'static ::kite_sql::db::Statement { - static DROP_TABLE_STATEMENT: ::std::sync::LazyLock<::kite_sql::db::Statement> = ::std::sync::LazyLock::new(|| { - ::kite_sql::orm::orm_drop_table_statement(#table_name_lit, false) - }); - &DROP_TABLE_STATEMENT - } - - fn drop_table_if_exists_statement() -> &'static ::kite_sql::db::Statement { - static DROP_TABLE_IF_EXISTS_STATEMENT: ::std::sync::LazyLock<::kite_sql::db::Statement> = ::std::sync::LazyLock::new(|| { - ::kite_sql::orm::orm_drop_table_statement(#table_name_lit, true) - }); - &DROP_TABLE_IF_EXISTS_STATEMENT - } - - fn analyze_statement() -> &'static ::kite_sql::db::Statement { - static ANALYZE_STATEMENT: ::std::sync::LazyLock<::kite_sql::db::Statement> = ::std::sync::LazyLock::new(|| { - ::kite_sql::orm::orm_analyze_statement(#table_name_lit) - }); - &ANALYZE_STATEMENT - } } }) } diff --git a/kite_sql_serde_macros/src/projection.rs b/kite_sql_serde_macros/src/projection.rs index 5dad654f..3d617019 100644 --- a/kite_sql_serde_macros/src/projection.rs +++ b/kite_sql_serde_macros/src/projection.rs @@ -33,7 +33,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { )); }; - let mut projected_values = Vec::new(); + let mut projection_exprs = Vec::new(); let mut assignments = Vec::new(); for field in data_struct.fields { @@ -62,13 +62,16 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { .predicates .push(parse_quote!(#field_ty : ::kite_sql::orm::FromDataValue)); - projected_values.push(if rename.is_some() { + projection_exprs.push(if rename.is_some() { quote! { - ::kite_sql::orm::projection_value(#source_name_lit, #relation_expr, #field_name_lit) + { + let expr = scope.column_ref(#relation_expr, #source_name_lit)?; + scope.alias(expr, #field_name_lit) + } } } else { quote! { - ::kite_sql::orm::projection_column(#source_name_lit, #relation_expr) + scope.column_ref(#relation_expr, #source_name_lit)? } }); assignments.push(quote! { @@ -78,21 +81,34 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { }); } + let mut from_generics = generics.clone(); + from_generics.params.insert(0, parse_quote!('__kite_arena)); + from_generics.params.insert(0, parse_quote!('__kite_schema)); + let (from_impl_generics, _, from_where_clause) = from_generics.split_for_impl(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); Ok(quote! { impl #impl_generics ::kite_sql::orm::Projection for #struct_name #ty_generics #where_clause { - fn projected_values(relation: &str) -> ::std::vec::Vec<::kite_sql::orm::ProjectedValue> { - vec![#(#projected_values),*] + fn bind_projection<'ctx, 'bind, 'parent, 'arena, T, A>( + scope: &mut ::kite_sql::orm::ExprBindScope<'ctx, 'bind, 'parent, 'arena, T, A>, + relation: &str, + ) -> ::std::result::Result<::std::vec::Vec<::kite_sql::expression::ScalarExpression>, ::kite_sql::errors::DatabaseError> + where + T: ::kite_sql::storage::Transaction, + A: AsRef<[(&'static str, ::kite_sql::types::value::DataValue)]>, + { + Ok(::std::vec![ + #(::kite_sql::orm::IntoOrmScalarExpression::into_orm_scalar(#projection_exprs)),* + ]) } } - impl #impl_generics From<(&::kite_sql::types::tuple::SchemaRef, ::kite_sql::types::tuple::Tuple)> for #struct_name #ty_generics - #where_clause + impl #from_impl_generics From<(&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)> for #struct_name #ty_generics + #from_where_clause { - fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaRef, ::kite_sql::types::tuple::Tuple)) -> Self { + fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)) -> Self { let mut struct_instance = ::default(); #(#assignments)* struct_instance diff --git a/kite_sql_serde_macros/src/reference_serialization.rs b/kite_sql_serde_macros/src/reference_serialization.rs index 81dc9114..3a46a193 100644 --- a/kite_sql_serde_macros/src/reference_serialization.rs +++ b/kite_sql_serde_macros/src/reference_serialization.rs @@ -110,10 +110,10 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { let ty = process_type(&field_opts.ty); encode_fields.push(quote! { - #field_name.encode(writer, is_direct, reference_tables)?; + #field_name.encode(writer, is_direct, reference_tables, arena)?; }); decode_fields.push(quote! { - let #field_name = #ty::decode(reader, drive, reference_tables)?; + let #field_name = #ty::decode(reader, drive, reference_tables, arena)?; }); init_fields.push(quote! { #field_name, @@ -127,11 +127,12 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { quote! { impl crate::serdes::ReferenceSerialization for #struct_name { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut crate::serdes::ReferenceTables, + arena: &A, ) -> Result<(), crate::errors::DatabaseError> { let #init_stream = self; @@ -140,10 +141,11 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &crate::serdes::ReferenceTables, + arena: &mut A, ) -> Result { #(#decode_fields)* @@ -173,10 +175,10 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { let ty = process_type(&field_opts.ty); encode_fields.push(quote! { - #field_name.encode(writer, is_direct, reference_tables)?; + #field_name.encode(writer, is_direct, reference_tables, arena)?; }); decode_fields.push(quote! { - let #field_name = #ty::decode(reader, drive, reference_tables)?; + let #field_name = #ty::decode(reader, drive, reference_tables, arena)?; }); init_fields.push(quote! { #field_name, @@ -206,11 +208,12 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { quote! { impl crate::serdes::ReferenceSerialization for #struct_name { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut crate::serdes::ReferenceTables, + arena: &A, ) -> Result<(), crate::errors::DatabaseError> { match self { #(#variant_encode_fields)* @@ -219,10 +222,11 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &crate::serdes::ReferenceTables, + arena: &mut A, ) -> Result { let mut type_bytes = [0u8; 1]; std::io::Read::read_exact(reader, &mut type_bytes)?; diff --git a/src/bin/server.rs b/src/bin/server.rs deleted file mode 100644 index cfcace5c..00000000 --- a/src/bin/server.rs +++ /dev/null @@ -1,400 +0,0 @@ -// Copyright 2024 KipData/KiteSQL -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use async_trait::async_trait; -use clap::Parser; -use futures::stream; -use kite_sql::db::{BorrowResultIter, DBTransaction, DataBaseBuilder, Database}; -use kite_sql::errors::DatabaseError; -use kite_sql::storage::rocksdb::RocksStorage; -use kite_sql::types::tuple::{SchemaRef, Tuple}; -use kite_sql::types::LogicalType; -use log::{error, info, LevelFilter}; -use parking_lot::Mutex; -use pgwire::api::auth::noop::NoopStartupHandler; -use pgwire::api::copy::NoopCopyHandler; -use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; -use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; -use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; -use pgwire::messages::data::DataRow; -use pgwire::tokio::process_socket; -use std::fmt::Debug; -use std::io; -use std::mem::transmute; -use std::ops::{Deref, DerefMut}; -use std::path::PathBuf; -use std::ptr::NonNull; -use std::sync::Arc; -use tokio::net::TcpListener; - -pub(crate) const BANNER: &str = " -oooo oooo o8o . .oooooo..o .oooooo. ooooo -`888 .8P' `\"' .o8 d8P' `Y8 d8P' `Y8b `888' - 888 d8' oooo .o888oo .ooooo. Y88bo. 888 888 888 - 88888[ `888 888 d88' `88b `\"Y8888o. 888 888 888 - 888`88b. 888 888 888ooo888 `\"Y88b 888 888 888 - 888 `88b. 888 888 . 888 .o oo .d8P `88b d88b 888 o -o888o o888o o888o \"888\" `Y8bod8P' 8\"\"88888P' `Y8bood8P'Ybd' o888ooooood8 - -"; - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - #[clap(long, default_value = "127.0.0.1")] - ip: String, - #[clap(long, default_value = "5432")] - port: u16, - #[clap(long, default_value = "./kitesql_data")] - path: String, -} - -struct TransactionPtr(NonNull>); - -impl Deref for TransactionPtr { - type Target = NonNull>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TransactionPtr { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -unsafe impl Send for TransactionPtr {} -unsafe impl Sync for TransactionPtr {} - -pub struct KiteSQLBackend { - inner: Arc>, -} - -impl KiteSQLBackend { - pub fn new(path: impl Into + Send) -> Result { - let database = DataBaseBuilder::path(path).build_rocksdb()?; - - Ok(KiteSQLBackend { - inner: Arc::new(database), - }) - } -} - -pub struct SessionBackend { - inner: Arc>, - tx: Mutex>, -} - -impl SessionBackend { - pub fn new(inner: Arc>) -> SessionBackend { - SessionBackend { - inner, - tx: Mutex::new(None), - } - } -} - -impl NoopStartupHandler for SessionBackend {} - -struct CustomBackendFactory { - handler: Arc, -} - -impl CustomBackendFactory { - pub fn new(handler: Arc) -> CustomBackendFactory { - CustomBackendFactory { handler } - } -} - -impl PgWireServerHandlers for CustomBackendFactory { - type StartupHandler = SessionBackend; - type SimpleQueryHandler = SessionBackend; - type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; - type CopyHandler = NoopCopyHandler; - type ErrorHandler = NoopErrorHandler; - - fn simple_query_handler(&self) -> Arc { - self.handler.clone() - } - - fn extended_query_handler(&self) -> Arc { - Arc::new(PlaceholderExtendedQueryHandler) - } - - fn startup_handler(&self) -> Arc { - self.handler.clone() - } - - fn copy_handler(&self) -> Arc { - Arc::new(NoopCopyHandler) - } - - fn error_handler(&self) -> Arc { - Arc::new(NoopErrorHandler) - } -} - -#[async_trait] -impl SimpleQueryHandler for SessionBackend { - async fn do_query<'a, 'b: 'a, C>( - &'b self, - _client: &mut C, - query: &'a str, - ) -> PgWireResult>> - where - C: ClientInfo + Unpin + Send + Sync, - { - match query.to_uppercase().as_str() { - "BEGIN;" | "BEGIN" | "START TRANSACTION;" | "START TRANSACTION" => { - let mut guard = self.tx.lock(); - - if guard.is_some() { - return Err(PgWireError::ApiError(Box::new( - DatabaseError::TransactionAlreadyExists, - ))); - } - let transaction = self - .inner - .new_transaction() - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - guard.replace(TransactionPtr( - Box::leak(Box::>::new(unsafe { - transmute::< - DBTransaction<'_, RocksStorage>, - DBTransaction<'static, RocksStorage>, - >(transaction) - })) - .into(), - )); - - Ok(vec![Response::Execution(Tag::new("OK"))]) - } - "COMMIT;" | "COMMIT" | "COMMIT WORK;" | "COMMIT WORK" => { - let mut guard = self.tx.lock(); - - if let Some(transaction) = guard.take() { - unsafe { Box::from_raw(transaction.as_ptr()) } - .commit() - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - - Ok(vec![Response::Execution(Tag::new("OK"))]) - } else { - Err(PgWireError::ApiError(Box::new( - DatabaseError::NoTransactionBegin, - ))) - } - } - "ROLLBACK;" | "ROLLBACK" => { - let mut guard = self.tx.lock(); - - if let Some(transaction) = guard.take() { - unsafe { drop(Box::from_raw(transaction.as_ptr())) } - } else { - return Err(PgWireError::ApiError(Box::new( - DatabaseError::NoTransactionBegin, - ))); - } - - Ok(vec![Response::Execution(Tag::new("OK"))]) - } - _ => { - let mut guard = self.tx.lock(); - - let response = if let Some(transaction) = guard.as_mut() { - let mut iter = unsafe { transaction.as_mut().run(query) } - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let response = encode_query_result(&mut iter)?; - iter.done() - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - response - } else { - let mut iter = self - .inner - .run(query) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let response = encode_query_result(&mut iter)?; - iter.done() - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - response - }; - Ok(vec![Response::Query(response)]) - } - } - } -} - -fn encode_query_result<'a, I>(iter: &mut I) -> PgWireResult> -where - I: BorrowResultIter, -{ - let fields = encode_fields(iter.schema())?; - let mut results = Vec::new(); - - while let Some(tuple) = iter - .next_borrowed_tuple() - .map_err(|e| PgWireError::ApiError(Box::new(e)))? - { - results.push(encode_tuple(fields.clone(), tuple)); - } - - Ok(QueryResponse::new(fields, stream::iter(results))) -} - -fn encode_fields(schema: &SchemaRef) -> PgWireResult>> { - Ok(Arc::new( - schema - .iter() - .map(|column| { - let pg_type = into_pg_type(column.datatype())?; - - Ok(FieldInfo::new( - column.name().into(), - None, - None, - pg_type, - FieldFormat::Text, - )) - }) - .collect::>>()?, - )) -} - -fn encode_tuple(schema: Arc>, tuple: &Tuple) -> PgWireResult { - let mut encoder = DataRowEncoder::new(schema); - for value in &tuple.values { - match value.logical_type() { - LogicalType::SqlNull => encoder.encode_field(&None::), - LogicalType::Boolean => encoder.encode_field(&value.bool()), - LogicalType::Tinyint => encoder.encode_field(&value.i8()), - LogicalType::UTinyint => encoder.encode_field(&value.u8().map(|v| v as i8)), - LogicalType::Smallint => encoder.encode_field(&value.i16()), - LogicalType::USmallint => encoder.encode_field(&value.u16().map(|v| v as i16)), - LogicalType::Integer => encoder.encode_field(&value.i32()), - LogicalType::UInteger => encoder.encode_field(&value.u32()), - LogicalType::Bigint => encoder.encode_field(&value.i64()), - LogicalType::UBigint => encoder.encode_field(&value.u64().map(|v| v as i64)), - LogicalType::Float => encoder.encode_field(&value.float()), - LogicalType::Double => encoder.encode_field(&value.double()), - LogicalType::Char(..) | LogicalType::Varchar(..) => encoder.encode_field(&value.utf8()), - LogicalType::Date => encoder.encode_field(&value.date()), - LogicalType::DateTime => encoder.encode_field(&value.datetime()), - LogicalType::Time(_) => encoder.encode_field(&value.time()), - LogicalType::Decimal(_, _) => { - encoder.encode_field(&value.decimal().map(|decimal| decimal.to_string())) - } - _ => unreachable!(), - }?; - } - - encoder.finish() -} - -fn into_pg_type(data_type: &LogicalType) -> PgWireResult { - Ok(match data_type { - LogicalType::SqlNull => Type::UNKNOWN, - LogicalType::Boolean => Type::BOOL, - LogicalType::Tinyint | LogicalType::UTinyint => Type::CHAR, - LogicalType::Smallint | LogicalType::USmallint => Type::INT2, - LogicalType::Integer | LogicalType::UInteger => Type::INT4, - LogicalType::Bigint | LogicalType::UBigint => Type::INT8, - LogicalType::Float => Type::FLOAT4, - LogicalType::Double => Type::FLOAT8, - LogicalType::Varchar(..) => Type::VARCHAR, - LogicalType::Date | LogicalType::DateTime => Type::DATE, - LogicalType::Char(..) => Type::CHAR, - LogicalType::Time(_) => Type::TIME, - LogicalType::Decimal(_, _) => Type::NUMERIC, - _ => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!("Unsupported Datatype {data_type}"), - )))); - } - }) -} - -async fn quit() -> io::Result<()> { - #[cfg(unix)] - { - let mut interrupt = - tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?; - let mut terminate = - tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?; - tokio::select! { - _ = interrupt.recv() => (), - _ = terminate.recv() => (), - } - Ok(()) - } - #[cfg(windows)] - { - let mut signal = tokio::signal::windows::ctrl_c()?; - let _ = signal.recv().await; - - Ok(()) - } -} - -#[tokio::main(worker_threads = 8)] -async fn main() { - env_logger::Builder::new() - .filter_level(LevelFilter::Info) - .init(); - - let args = Args::parse(); - info!("{} \nVersion: {}\n", BANNER, env!("CARGO_PKG_VERSION")); - info!(":) Welcome to the KiteSQL🪁"); - info!("Listen on port {}", args.port); - info!("Tips: "); - info!( - "1. all data is in the \'{}\' folder in the directory where the application is run", - args.path - ); - - let backend = KiteSQLBackend::new(args.path).unwrap(); - let factory = Arc::new(CustomBackendFactory::new(Arc::new(SessionBackend::new( - backend.inner, - )))); - let server_addr = format!("{}:{}", args.ip, args.port); - let listener = TcpListener::bind(server_addr).await.unwrap(); - - tokio::select! { - res = server_run(listener,factory) => { - if let Err(err) = res { - error!("[Listener][Failed To Accept]: {err}"); - } - } - _ = quit() => info!("Bye!") - } -} - -async fn server_run( - listener: TcpListener, - factory_ref: Arc, -) -> io::Result<()> { - loop { - let incoming_socket = listener.accept().await?; - let factory_ref = factory_ref.clone(); - - tokio::spawn(async move { - if let Err(err) = process_socket(incoming_socket.0, None, factory_ref).await { - error!("Failed To Process: {err}"); - } - }); - } -} diff --git a/src/bin/shell.rs b/src/bin/shell.rs index 10b58f8a..23bb7dae 100644 --- a/src/bin/shell.rs +++ b/src/bin/shell.rs @@ -174,14 +174,37 @@ Transaction commands: where I: BorrowResultIter, { - let mut table = Table::new(); - let schema = iter.schema().clone(); + let (table, schema_len, row_count) = create_table(&mut iter)?; + iter.done()?; - if !schema.is_empty() { - let header = schema - .iter() - .map(|column| Cell::new(column.full_name())) - .collect::>(); + if schema_len == 0 { + println!("OK"); + } else if row_count == 0 { + println!("{table}"); + println!("0 rows"); + } else { + println!("{table}"); + println!("{row_count} row{}", if row_count == 1 { "" } else { "s" }); + } + + Ok(()) + } + + fn create_table(iter: &mut I) -> Result<(Table, usize, usize), DatabaseError> + where + I: BorrowResultIter, + { + let mut table = Table::new(); + let (header, schema_len) = iter.schema(|schema| { + ( + schema + .iter() + .map(|column| Cell::new(column.full_name())) + .collect::>(), + schema.len(), + ) + }); + if !header.is_empty() { table.set_header(header); } @@ -195,19 +218,8 @@ Transaction commands: .collect::>(); table.add_row(row); } - iter.done()?; - - if schema.is_empty() { - println!("OK"); - } else if row_count == 0 { - println!("{table}"); - println!("0 rows"); - } else { - println!("{table}"); - println!("{row_count} row{}", if row_count == 1 { "" } else { "s" }); - } - Ok(()) + Ok((table, schema_len, row_count)) } fn run_sql<'a>( diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index f372b414..d88d4dc3 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -14,7 +14,6 @@ use ahash::RandomState; use itertools::Itertools; -use sqlparser::ast::{Expr, OrderByExpr}; use std::collections::HashSet; use super::{Binder, QueryBindStep}; @@ -55,16 +54,11 @@ impl> Binder<'_, '_, T, A> Ok(()) } - pub fn extract_group_by_aggregate( + pub fn extract_group_by_aggregate_exprs( &mut self, select_list: &mut [ScalarExpression], - groupby: &[Expr], + mut group_by_exprs: Vec, ) -> Result<(), DatabaseError> { - let mut group_by_exprs = Vec::with_capacity(groupby.len()); - for expr in groupby.iter() { - group_by_exprs.push(self.bind_expr(expr)?); - } - self.validate_groupby_illegal_column(select_list, &group_by_exprs)?; for expr in group_by_exprs.iter_mut() { @@ -73,48 +67,53 @@ impl> Binder<'_, '_, T, A> Ok(()) } - pub fn extract_having_orderby_aggregate( + pub fn extract_having_orderby_aggregate_exprs( &mut self, - having: &Option, - orderbys: &[OrderByExpr], - ) -> Result<(Option, Option>), DatabaseError> { - // Extract having expression. - let return_having = if let Some(having) = having { - let mut having = self.bind_expr(having)?; - self.visit_column_agg_expr(&mut having)?; - - Some(having) - } else { - None - }; - - // Extract orderby expression. - let return_orderby = if !orderbys.is_empty() { - let mut return_orderby = vec![]; - for orderby in orderbys { - let OrderByExpr { expr, options, .. } = orderby; - let mut expr = self.bind_expr(expr)?; - self.visit_column_agg_expr(&mut expr)?; - - return_orderby.push(SortField::new( - expr, - options.asc.is_none_or(|asc| asc), - options.nulls_first.unwrap_or(false), - )); + mut having: Option, + orderby: Option, + mut bind_sort_field: F, + ) -> Result<(Option, Option>), DatabaseError> + where + I: IntoIterator, + F: FnMut(&mut Self, I::Item) -> Result, + { + if let Some(having) = having.as_mut() { + self.visit_column_agg_expr(having)?; + } + let mut return_orderby = None; + if let Some(orderby) = orderby { + let mut fields = Vec::new(); + for orderby in orderby { + let mut field = bind_sort_field(self, orderby)?; + self.visit_column_agg_expr(&mut field.expr)?; + fields.push(field); } - Some(return_orderby) - } else { - None - }; - Ok((return_having, return_orderby)) + return_orderby = Some(fields); + } + Ok((having, return_orderby)) } pub fn bind_aggregate_output_exprs<'c>( + &mut self, + exprs: impl IntoIterator, + arena: &mut crate::planner::PlanArena, + ) -> Result<(), DatabaseError> { + self.bind_aggregate_output_exprs_with_outputs( + &self.context.agg_calls, + &self.context.group_by_exprs, + exprs, + arena, + ) + } + + pub(crate) fn bind_aggregate_output_exprs_with_outputs<'c>( &self, + agg_calls: &[ScalarExpression], + group_by_exprs: &[ScalarExpression], exprs: impl IntoIterator, + arena: &mut crate::planner::PlanArena, ) -> Result<(), DatabaseError> { - let mut binder = - AggregateOutputBinder::new(&self.context.agg_calls, &self.context.group_by_exprs); + let mut binder = AggregateOutputBinder::new(agg_calls, group_by_exprs, arena); for expr in exprs { binder.visit(expr)?; } @@ -470,20 +469,30 @@ impl> Binder<'_, '_, T, A> } } -struct AggregateOutputBinder<'a> { +struct AggregateOutputBinder<'a, 'p> { agg_calls: &'a [ScalarExpression], group_by_exprs: &'a [ScalarExpression], + arena: &'a mut crate::planner::PlanArena<'p>, } -impl<'a> AggregateOutputBinder<'a> { - fn new(agg_calls: &'a [ScalarExpression], group_by_exprs: &'a [ScalarExpression]) -> Self { +impl<'a, 'p> AggregateOutputBinder<'a, 'p> { + fn new( + agg_calls: &'a [ScalarExpression], + group_by_exprs: &'a [ScalarExpression], + arena: &'a mut crate::planner::PlanArena<'p>, + ) -> Self { Self { agg_calls, group_by_exprs, + arena, } } - fn output_ref(&self, expr: &ScalarExpression) -> Option { + fn output_ref( + &mut self, + expr: &ScalarExpression, + ) -> Result, DatabaseError> { + let output_count = self.agg_calls.len() + self.group_by_exprs.len(); self.agg_calls .iter() .chain(self.group_by_exprs.iter()) @@ -496,13 +505,21 @@ impl<'a> AggregateOutputBinder<'a> { .iter() .chain(self.group_by_exprs.iter()) .nth(position) - .unwrap(); - ScalarExpression::column_expr(output_expr.output_column(), position) + .ok_or_else(|| { + DatabaseError::InvalidValue(format!( + "aggregate output position {position} is out of bounds for {output_count} output expressions" + )) + })?; + Ok(ScalarExpression::column_expr( + output_expr.output_column_ref(self.arena), + position, + )) }) + .transpose() } } -impl<'a> VisitorMut<'a> for AggregateOutputBinder<'_> { +impl<'a> VisitorMut<'a> for AggregateOutputBinder<'_, '_> { fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { if let ScalarExpression::Alias { expr: inner_expr, @@ -512,7 +529,7 @@ impl<'a> VisitorMut<'a> for AggregateOutputBinder<'_> { return self.visit(inner_expr); } - if let Some(output_ref) = self.output_ref(expr) { + if let Some(output_ref) = self.output_ref(expr)? { *expr = output_ref; return Ok(()); } @@ -528,10 +545,11 @@ mod tests { use crate::expression::agg::AggKind; use crate::expression::visitor_mut::VisitorMut; use crate::expression::{AliasType, ScalarExpression}; + use crate::planner::PlanArena; use crate::types::LogicalType; - fn test_column(name: &str, ty: LogicalType) -> ColumnRef { - ColumnRef::from(ColumnCatalog::new( + fn test_column(arena: &mut PlanArena, name: &str, ty: LogicalType) -> ColumnRef { + arena.alloc_column(ColumnCatalog::new( name.to_string(), true, ColumnDesc::new(ty, None, false, None).unwrap(), @@ -549,8 +567,10 @@ mod tests { #[test] fn test_aggregate_output_binder_rewrites_agg_and_group_slots() -> Result<(), DatabaseError> { - let group_column = test_column("c1", LogicalType::Integer); - let agg_column = test_column("c2", LogicalType::Integer); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); + let group_column = test_column(&mut arena, "c1", LogicalType::Integer); + let agg_column = test_column(&mut arena, "c2", LogicalType::Integer); let group_expr = ScalarExpression::column_expr(group_column, 0); let agg_expr = test_count(ScalarExpression::column_expr(agg_column, 1)); @@ -564,54 +584,59 @@ mod tests { alias: AliasType::Name("g".to_string()), }; - let mut binder = AggregateOutputBinder::new( - std::slice::from_ref(&agg_output), - std::slice::from_ref(&group_output), - ); - let mut order_by_agg = ScalarExpression::Alias { expr: Box::new(agg_expr), alias: AliasType::Name("cnt".to_string()), }; - binder.visit(&mut order_by_agg)?; - assert_eq!( - order_by_agg, - ScalarExpression::Alias { - expr: Box::new(ScalarExpression::column_expr(agg_output.output_column(), 0)), - alias: AliasType::Name("cnt".to_string()), - } - ); - let mut order_by_group = group_expr; - binder.visit(&mut order_by_group)?; - assert_eq!( - order_by_group, - ScalarExpression::column_expr(group_output.output_column(), 1) - ); + { + let mut binder = AggregateOutputBinder::new( + std::slice::from_ref(&agg_output), + std::slice::from_ref(&group_output), + &mut arena, + ); + binder.visit(&mut order_by_agg)?; + binder.visit(&mut order_by_group)?; + } + let expected_agg = ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr( + agg_output.output_column_ref(&mut arena), + 0, + )), + alias: AliasType::Name("cnt".to_string()), + }; + assert!(order_by_agg.eq_ignore_colref_pos(&expected_agg, &arena)); + + let expected_group = + ScalarExpression::column_expr(group_output.output_column_ref(&mut arena), 1); + assert!(order_by_group.eq_ignore_colref_pos(&expected_group, &arena)); Ok(()) } #[test] fn test_aggregate_output_binder_matches_alias_expr_reference() -> Result<(), DatabaseError> { - let group_column = test_column("c1", LogicalType::Integer); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); + let group_column = test_column(&mut arena, "c1", LogicalType::Integer); let group_expr = ScalarExpression::column_expr(group_column, 0); let group_output = ScalarExpression::Alias { expr: Box::new(group_expr.clone()), alias: AliasType::Name("g".to_string()), }; - let mut binder = AggregateOutputBinder::new(&[], std::slice::from_ref(&group_output)); let mut target = ScalarExpression::Alias { expr: Box::new(ScalarExpression::Constant(1_i32.into())), alias: AliasType::Expr(Box::new(group_expr)), }; - binder.visit(&mut target)?; - assert_eq!( - target, - ScalarExpression::column_expr(group_output.output_column(), 0) - ); + { + let mut binder = + AggregateOutputBinder::new(&[], std::slice::from_ref(&group_output), &mut arena); + binder.visit(&mut target)?; + } + let expected = ScalarExpression::column_expr(group_output.output_column_ref(&mut arena), 0); + assert!(target.eq_ignore_colref_pos(&expected, &arena)); Ok(()) } diff --git a/src/binder/alter_table.rs b/src/binder/alter_table.rs index 03e3deef..b18e27f8 100644 --- a/src/binder/alter_table.rs +++ b/src/binder/alter_table.rs @@ -12,294 +12,72 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlparser::ast::{AlterColumnOperation, AlterTableOperation, ColumnOption, ObjectName}; - -use std::borrow::Cow; - -use super::{attach_span_if_absent, is_valid_identifier, Binder}; -use crate::binder::{lower_case_name, lower_ident}; -use crate::catalog::TableName; +use super::Binder; +use crate::catalog::{ColumnCatalog, TableName}; use crate::errors::DatabaseError; -use crate::expression::ScalarExpression; use crate::planner::operator::alter_table::add_column::AddColumnOperator; use crate::planner::operator::alter_table::change_column::{ ChangeColumnOperator, DefaultChange, NotNullChange, }; use crate::planner::operator::alter_table::drop_column::DropColumnOperator; use crate::planner::operator::Operator; -use crate::planner::Childrens; -use crate::planner::LogicalPlan; +use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; use crate::types::LogicalType; impl> Binder<'_, '_, T, A> { - fn bind_alter_default_expr( + pub(crate) fn bind_add_column( &mut self, - expr: &sqlparser::ast::Expr, - ty: &LogicalType, - ) -> Result { - let mut expr = self.bind_expr(expr)?; - - if expr.any_referenced_column(true, |_| true) { - return Err(DatabaseError::UnsupportedStmt( - "column is not allowed to exist in default".to_string(), - )); - } - expr = ScalarExpression::type_cast(expr, Cow::Borrowed(ty))?; - - Ok(expr) + table_name: TableName, + column: ColumnCatalog, + if_not_exists: bool, + ) -> Result { + Ok(LogicalPlan::new( + Operator::AddColumn(AddColumnOperator { + table_name, + if_not_exists, + column, + }), + Childrens::None, + )) } - fn bind_change_column_options( + pub(crate) fn bind_drop_column( &mut self, - options: &[ColumnOption], - data_type: &LogicalType, - ) -> Result<(DefaultChange, NotNullChange), DatabaseError> { - let mut default_change = DefaultChange::NoChange; - let mut not_null_change = NotNullChange::NoChange; - - for option in options { - match option { - ColumnOption::Null => not_null_change = NotNullChange::Drop, - ColumnOption::NotNull => not_null_change = NotNullChange::Set, - ColumnOption::Default(expr) => { - default_change = - DefaultChange::Set(self.bind_alter_default_expr(expr, data_type)?); - } - option => { - return Err(DatabaseError::UnsupportedStmt(format!( - "CHANGE/MODIFY COLUMN does not currently support this option: {option:?}" - ))) - } - } - } - - Ok((default_change, not_null_change)) + table_name: TableName, + column_name: String, + if_exists: bool, + ) -> Result { + Ok(LogicalPlan::new( + Operator::DropColumn(DropColumnOperator { + table_name, + if_exists, + column_name, + }), + Childrens::None, + )) } - pub(crate) fn bind_alter_table( + pub(crate) fn bind_change_column( &mut self, - name: &ObjectName, - operation: &AlterTableOperation, + table_name: TableName, + old_column_name: String, + new_column_name: String, + data_type: LogicalType, + default_change: DefaultChange, + not_null_change: NotNullChange, ) -> Result { - let table_name: TableName = lower_case_name(name)?.into(); - let table = self - .context - .table(table_name.clone())? - .ok_or(DatabaseError::TableNotFound)?; - let plan = match operation { - AlterTableOperation::AddColumn { - column_keyword: _, - if_not_exists, - column_def, - .. - } => { - let column = self.bind_column(column_def, None)?; - - if !is_valid_identifier(column.name()) { - return Err(attach_span_if_absent( - DatabaseError::invalid_column("illegal column naming".to_string()), - column_def, - )); - } - LogicalPlan::new( - Operator::AddColumn(AddColumnOperator { - table_name, - if_not_exists: *if_not_exists, - column, - }), - Childrens::None, - ) - } - AlterTableOperation::DropColumn { - column_names, - if_exists, - .. - } => { - if column_names.len() != 1 { - return Err(DatabaseError::UnsupportedStmt( - "only dropping a single column is supported".to_string(), - )); - } - let column_name = column_names[0].value.clone(); - - LogicalPlan::new( - Operator::DropColumn(DropColumnOperator { - table_name, - if_exists: *if_exists, - column_name, - }), - Childrens::None, - ) - } - AlterTableOperation::RenameColumn { + Ok(LogicalPlan::new( + Operator::ChangeColumn(ChangeColumnOperator { + table_name, old_column_name, new_column_name, - } => { - let old_column_name = lower_ident(old_column_name); - let new_column_name = lower_ident(new_column_name).into_owned(); - let old_column = table - .get_column_by_name(old_column_name.as_ref()) - .ok_or_else(|| DatabaseError::column_not_found(old_column_name.to_string()))?; - - if !is_valid_identifier(&new_column_name) { - return Err(DatabaseError::invalid_column( - "illegal column naming".to_string(), - )); - } - - LogicalPlan::new( - Operator::ChangeColumn(ChangeColumnOperator { - table_name, - old_column_name: old_column_name.into_owned(), - new_column_name, - data_type: old_column.datatype().clone(), - default_change: DefaultChange::NoChange, - not_null_change: NotNullChange::NoChange, - }), - Childrens::None, - ) - } - AlterTableOperation::AlterColumn { column_name, op } => { - let old_column_name = lower_ident(column_name); - let old_column = table - .get_column_by_name(old_column_name.as_ref()) - .ok_or_else(|| DatabaseError::column_not_found(old_column_name.to_string()))?; - let old_data_type = old_column.datatype().clone(); - - let (data_type, default_change, not_null_change) = match op { - AlterColumnOperation::SetDataType { - data_type, using, .. - } => { - if using.is_some() { - return Err(DatabaseError::UnsupportedStmt( - "ALTER COLUMN TYPE USING is not supported".to_string(), - )); - } - ( - LogicalType::try_from(data_type.clone())?, - DefaultChange::NoChange, - NotNullChange::NoChange, - ) - } - AlterColumnOperation::SetDefault { value } => ( - old_data_type.clone(), - DefaultChange::Set(self.bind_alter_default_expr(value, &old_data_type)?), - NotNullChange::NoChange, - ), - AlterColumnOperation::DropDefault => ( - old_data_type.clone(), - DefaultChange::Drop, - NotNullChange::NoChange, - ), - AlterColumnOperation::SetNotNull => ( - old_data_type.clone(), - DefaultChange::NoChange, - NotNullChange::Set, - ), - AlterColumnOperation::DropNotNull => ( - old_data_type.clone(), - DefaultChange::NoChange, - NotNullChange::Drop, - ), - _ => { - return Err(DatabaseError::UnsupportedStmt(format!( - "unsupported alter column operation: {op:?}" - ))) - } - }; - - LogicalPlan::new( - Operator::ChangeColumn(ChangeColumnOperator { - table_name, - new_column_name: old_column_name.to_string(), - old_column_name: old_column_name.into_owned(), - data_type, - default_change, - not_null_change, - }), - Childrens::None, - ) - } - AlterTableOperation::ModifyColumn { - col_name, data_type, - options, - column_position, - } => { - if column_position.is_some() { - return Err(DatabaseError::UnsupportedStmt( - "MODIFY COLUMN does not currently support column positions".to_string(), - )); - } - let old_column_name = lower_ident(col_name); - let _ = table - .get_column_by_name(old_column_name.as_ref()) - .ok_or_else(|| DatabaseError::column_not_found(old_column_name.to_string()))?; - let old_column_name = old_column_name.into_owned(); - let data_type = LogicalType::try_from(data_type.clone())?; - let (default_change, not_null_change) = - self.bind_change_column_options(options, &data_type)?; - - LogicalPlan::new( - Operator::ChangeColumn(ChangeColumnOperator { - table_name, - new_column_name: old_column_name.clone(), - old_column_name, - data_type, - default_change, - not_null_change, - }), - Childrens::None, - ) - } - AlterTableOperation::ChangeColumn { - old_name, - new_name, - data_type, - options, - column_position, - } => { - if column_position.is_some() { - return Err(DatabaseError::UnsupportedStmt( - "CHANGE COLUMN does not currently support column positions".to_string(), - )); - } - let old_column_name = lower_ident(old_name); - let new_column_name = lower_ident(new_name).into_owned(); - let _ = table - .get_column_by_name(old_column_name.as_ref()) - .ok_or_else(|| DatabaseError::column_not_found(old_column_name.to_string()))?; - - if !is_valid_identifier(&new_column_name) { - return Err(DatabaseError::invalid_column( - "illegal column naming".to_string(), - )); - } - let data_type = LogicalType::try_from(data_type.clone())?; - let (default_change, not_null_change) = - self.bind_change_column_options(options, &data_type)?; - - LogicalPlan::new( - Operator::ChangeColumn(ChangeColumnOperator { - table_name, - old_column_name: old_column_name.into_owned(), - new_column_name, - data_type, - default_change, - not_null_change, - }), - Childrens::None, - ) - } - op => { - return Err(DatabaseError::UnsupportedStmt(format!( - "AlertOperation: {op:?}" - ))) - } - }; - - Ok(plan) + default_change, + not_null_change, + }), + Childrens::None, + )) } } diff --git a/src/binder/analyze.rs b/src/binder/analyze.rs index adf4d0da..b5b394b0 100644 --- a/src/binder/analyze.rs +++ b/src/binder/analyze.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder, Source}; +use crate::binder::{Binder, Source}; use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::planner::operator::analyze::AnalyzeOperator; @@ -21,12 +21,13 @@ use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use sqlparser::ast::ObjectName; impl> Binder<'_, '_, T, A> { - pub(crate) fn bind_analyze(&mut self, name: &ObjectName) -> Result { - let table_name: TableName = lower_case_name(name)?.into(); - + pub(crate) fn bind_analyze( + &mut self, + table_name: TableName, + arena: &crate::planner::PlanArena, + ) -> Result { let table = self .context .source_and_bind(table_name.clone(), None, None, true)? @@ -40,7 +41,7 @@ impl> Binder<'_, '_, T, A> .ok_or(DatabaseError::TableNotFound)?; let index_metas = table.indexes.clone(); - let scan_op = TableScanOperator::build(table_name.clone(), table, false)?; + let scan_op = TableScanOperator::build(table_name.clone(), table, false, arena)?; Ok(LogicalPlan::new( Operator::Analyze(AnalyzeOperator { table_name, diff --git a/src/binder/copy.rs b/src/binder/copy.rs index 667242c8..ad33ab4e 100644 --- a/src/binder/copy.rs +++ b/src/binder/copy.rs @@ -17,15 +17,12 @@ use std::str::FromStr; use super::*; use crate::catalog::TableName; -use crate::errors::DatabaseError; use crate::planner::operator::copy_from_file::CopyFromFileOperator; use crate::planner::operator::copy_to_file::CopyToFileOperator; use crate::planner::operator::table_scan::TableScanOperator; use crate::planner::operator::Operator; use crate::planner::Childrens; use kite_sql_serde_macros::ReferenceSerialization; -use serde::{Deserialize, Serialize}; -use sqlparser::ast::{CopyOption, CopySource, CopyTarget}; #[derive(Debug, PartialEq, PartialOrd, Ord, Hash, Eq, Clone, ReferenceSerialization)] pub struct ExtSource { @@ -34,18 +31,7 @@ pub struct ExtSource { } /// File format. -#[derive( - Debug, - PartialEq, - PartialOrd, - Ord, - Hash, - Eq, - Clone, - Serialize, - Deserialize, - ReferenceSerialization, -)] +#[derive(Debug, PartialEq, PartialOrd, Ord, Hash, Eq, Clone, ReferenceSerialization)] pub enum FileFormat { Csv { /// Delimiter to parse. @@ -79,55 +65,34 @@ impl FromStr for ExtSource { } impl> Binder<'_, '_, T, A> { - pub(super) fn bind_copy( + pub(super) fn bind_copy_to_file( &mut self, - source: CopySource, - to: bool, - target: CopyTarget, - options: &[CopyOption], + target: ExtSource, + input: LogicalPlan, ) -> Result { - let ext_source = copy_ext_source(target, options)?; - - let (table_name, ..) = match source { - CopySource::Table { - table_name, - columns, - } => (table_name, columns), - CopySource::Query(query) => { - if !to { - return Err(DatabaseError::UnsupportedStmt( - "'COPY FROM query'".to_string(), - )); - } - let mut input_plan = self.bind_query(&query)?; - let schema_ref = input_plan.output_schema().clone(); - return Ok(LogicalPlan::new( - Operator::CopyToFile(CopyToFileOperator { - target: ext_source, - schema_ref, - }), - Childrens::Only(Box::new(input_plan)), - )); - } - }; - let table_name: TableName = lower_case_name(&table_name)?.into(); - - if let Some(table) = self.context.table(table_name.clone())? { - let schema_ref = table.schema_ref().clone(); + Ok(LogicalPlan::new( + Operator::CopyToFile(CopyToFileOperator { target }), + Childrens::Only(Box::new(input)), + )) + } + pub(super) fn bind_copy_table( + &mut self, + table_name: TableName, + to: bool, + ext_source: ExtSource, + arena: &crate::planner::PlanArena, + ) -> Result { + if let Some(table) = self.context.table(table_name.clone())?.cloned() { if to { - // COPY TO Ok(LogicalPlan::new( - Operator::CopyToFile(CopyToFileOperator { - target: ext_source, - schema_ref, - }), + Operator::CopyToFile(CopyToFileOperator { target: ext_source }), Childrens::Only(Box::new(TableScanOperator::build( - table_name, table, false, + table_name, &table, false, arena, )?)), )) } else { - // COPY FROM + let schema_ref = table.columns().copied().collect(); Ok(LogicalPlan::new( Operator::CopyFromFile(CopyFromFileOperator { source: ext_source, @@ -142,45 +107,3 @@ impl> Binder<'_, '_, T, A> } } } - -fn copy_ext_source(target: CopyTarget, options: &[CopyOption]) -> Result { - Ok(ExtSource { - path: match target { - CopyTarget::File { filename } => filename.into(), - t => { - return Err(DatabaseError::UnsupportedStmt(format!( - "copy target: {t:?}" - ))) - } - }, - format: FileFormat::from_options(options), - }) -} - -impl FileFormat { - /// Create from copy options. - pub fn from_options(options: &[CopyOption]) -> Self { - let mut delimiter = ','; - let mut quote = '"'; - let mut escape = None; - let mut header = false; - for opt in options { - match opt { - CopyOption::Format(fmt) => { - debug_assert_eq!(fmt.value.to_lowercase(), "csv", "only support CSV format") - } - CopyOption::Delimiter(c) => delimiter = *c, - CopyOption::Header(b) => header = *b, - CopyOption::Quote(c) => quote = *c, - CopyOption::Escape(c) => escape = Some(*c), - o => panic!("unsupported copy option: {o:?}"), - } - } - FileFormat::Csv { - delimiter, - quote, - escape, - header, - } - } -} diff --git a/src/binder/create_index.rs b/src/binder/create_index.rs index 7aa3f259..44187aec 100644 --- a/src/binder/create_index.rs +++ b/src/binder/create_index.rs @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder, Source}; -use crate::catalog::TableName; +use crate::binder::{Binder, Source}; +use crate::catalog::{ColumnRef, TableName}; use crate::errors::DatabaseError; -use crate::expression::ScalarExpression; use crate::planner::operator::create_index::CreateIndexOperator; use crate::planner::operator::table_scan::TableScanOperator; use crate::planner::operator::Operator; @@ -23,65 +22,54 @@ use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::index::IndexType; use crate::types::value::DataValue; -use sqlparser::ast::{IndexColumn, ObjectName}; impl> Binder<'_, '_, T, A> { + pub(crate) fn bind_create_index_source( + &mut self, + table_name: TableName, + arena: &mut crate::planner::PlanArena, + ) -> Result { + let source = self + .context + .source_and_bind(table_name.clone(), None, None, false)? + .ok_or(DatabaseError::SourceNotFound)?; + match source { + Source::Table(table) => { + TableScanOperator::build(table_name.clone(), table, true, arena) + } + Source::View(view) => Ok(LogicalPlan::clone(&view.plan)), + Source::Schema(_) => Err(DatabaseError::UnsupportedStmt( + "derived source cannot be rebound as a base relation".to_string(), + )), + } + } + pub(crate) fn bind_create_index( &mut self, - table_name: &ObjectName, - name: Option<&ObjectName>, - index_columns: &[IndexColumn], + table_name: TableName, + index_name: String, + columns: Vec, if_not_exists: bool, is_unique: bool, + input: LogicalPlan, ) -> Result { - let table_name: TableName = lower_case_name(table_name)?.into(); - let index_name = name - .ok_or(DatabaseError::InvalidIndex) - .and_then(lower_case_name)?; let ty = if is_unique { IndexType::Unique - } else if index_columns.len() == 1 { + } else if columns.len() == 1 { IndexType::Normal } else { IndexType::Composite }; - let source = self - .context - .source_and_bind(table_name.clone(), None, None, false)? - .ok_or(DatabaseError::SourceNotFound)?; - let plan = match source { - Source::Table(table) => TableScanOperator::build(table_name.clone(), table, true)?, - Source::View(view) => LogicalPlan::clone(&view.plan), - Source::Schema(_) => { - return Err(DatabaseError::UnsupportedStmt( - "derived source cannot be rebound as a base relation".to_string(), - )) - } - }; - let mut columns = Vec::with_capacity(index_columns.len()); - - for index_column in index_columns { - // TODO: Expression Index - match self.bind_expr(&index_column.column.expr)? { - ScalarExpression::ColumnRef { column, .. } => columns.push(column), - expr => { - return Err(DatabaseError::UnsupportedStmt(format!( - "'CREATE INDEX' by {expr}" - ))) - } - } - } - Ok(LogicalPlan::new( Operator::CreateIndex(CreateIndexOperator { table_name, columns, - index_name: index_name.into_owned(), + index_name, if_not_exists, ty, }), - Childrens::Only(Box::new(plan)), + Childrens::Only(Box::new(input)), )) } } diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 9a043955..f4e56a56 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -12,86 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. -use super::{attach_span_if_absent, is_valid_identifier, Binder}; -use crate::binder::{lower_case_name, lower_ident}; -use crate::catalog::{ColumnCatalog, ColumnDesc, TableName}; +use super::{is_valid_identifier, Binder}; +use crate::catalog::{ColumnCatalog, TableName}; use crate::errors::DatabaseError; -use crate::expression::ScalarExpression; use crate::planner::operator::create_table::CreateTableOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use crate::types::LogicalType; -use itertools::Itertools; -use sqlparser::ast::{ColumnDef, ColumnOption, Expr, IndexColumn, ObjectName, TableConstraint}; -use std::borrow::Cow; use std::collections::HashSet; impl> Binder<'_, '_, T, A> { // TODO: TableConstraint pub(crate) fn bind_create_table( &mut self, - name: &ObjectName, - columns: &[ColumnDef], - constraints: &[TableConstraint], + table_name: TableName, + columns: Vec, if_not_exists: bool, ) -> Result { - let table_name: TableName = lower_case_name(name)?.into(); - - if !is_valid_identifier(&table_name) { - return Err(attach_span_if_absent( - DatabaseError::invalid_table("illegal table naming".to_string()), - name, - )); - } - { - // check duplicated column names - let mut set = HashSet::new(); - for col in columns.iter() { - let col_name = &col.name.value; - if !set.insert(col_name) { - return Err(DatabaseError::DuplicateColumn(col_name.clone())); - } - if !is_valid_identifier(col_name) { - return Err(attach_span_if_absent( - DatabaseError::invalid_column("illegal column naming".to_string()), - col, - )); - } + let mut names = HashSet::new(); + for column in &columns { + if !names.insert(column.name()) { + return Err(DatabaseError::DuplicateColumn(column.name().to_string())); } - } - let mut columns: Vec = columns - .iter() - .enumerate() - .map(|(i, col)| self.bind_column(col, Some(i))) - .try_collect()?; - for constraint in constraints { - match constraint { - TableConstraint::PrimaryKey(primary) => { - Self::bind_constraint(&mut columns, &primary.columns, |i, desc| { - desc.set_primary(Some(i)) - })?; - } - TableConstraint::Unique(unique) => { - Self::bind_constraint(&mut columns, &unique.columns, |_, desc| { - desc.set_unique() - })?; - } - constraint => { - return Err(DatabaseError::UnsupportedStmt(format!( - "`CreateTable` does not currently support this constraint: {constraint:?}" - )))? - } + if !is_valid_identifier(column.name()) { + return Err(DatabaseError::invalid_column( + "illegal column naming".to_string(), + )); } } if columns.iter().filter(|col| col.desc().is_primary()).count() == 0 { - return Err(attach_span_if_absent( - DatabaseError::invalid_table( - "the primary key field must exist and have at least one".to_string(), - ), - name, + return Err(DatabaseError::invalid_table( + "the primary key field must exist and have at least one".to_string(), )); } @@ -104,79 +57,6 @@ impl> Binder<'_, '_, T, A> Childrens::None, )) } - - fn bind_constraint( - table_columns: &mut [ColumnCatalog], - exprs: &[IndexColumn], - fn_constraint: F, - ) -> Result<(), DatabaseError> { - for (i, index_column) in exprs.iter().enumerate() { - let Expr::Identifier(ident) = &index_column.column.expr else { - return Err(DatabaseError::UnsupportedStmt( - "only identifier columns are supported in `PRIMARY KEY/UNIQUE`".to_string(), - )); - }; - let column_name = lower_ident(ident); - - if let Some(column) = table_columns - .iter_mut() - .find(|column| column.name() == column_name.as_ref()) - { - fn_constraint(i, column.desc_mut()) - } - } - Ok(()) - } - - pub fn bind_column( - &mut self, - column_def: &ColumnDef, - column_index: Option, - ) -> Result { - let column_name = lower_ident(&column_def.name).into_owned(); - let mut column_desc = ColumnDesc::new( - LogicalType::try_from(column_def.data_type.clone())?, - None, - false, - None, - )?; - let mut nullable = true; - - for option_def in &column_def.options { - match &option_def.option { - ColumnOption::Null => nullable = true, - ColumnOption::NotNull => nullable = false, - ColumnOption::PrimaryKey(_) => { - column_desc.set_primary(column_index); - nullable = false; - // Skip other options when using primary key - break; - } - ColumnOption::Unique(_) => column_desc.set_unique(), - ColumnOption::Default(expr) => { - let mut expr = self.bind_expr(expr)?; - - if expr.any_referenced_column(true, |_| true) { - return Err(DatabaseError::UnsupportedStmt( - "column is not allowed to exist in `default`".to_string(), - )); - } - expr = ScalarExpression::type_cast( - expr, - Cow::Borrowed(&column_desc.column_datatype), - )?; - column_desc.default = Some(expr); - } - option => { - return Err(DatabaseError::UnsupportedStmt(format!( - "`Column` does not currently support this option: {option:?}" - ))) - } - } - } - - Ok(ColumnCatalog::new(column_name, nullable, column_desc)) - } } #[cfg(all(test, not(target_arch = "wasm32")))] @@ -188,10 +68,6 @@ mod tests { use crate::storage::Storage; use crate::types::CharLengthUnits; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; - use std::hash::RandomState; - use std::sync::atomic::AtomicUsize; - use std::sync::Arc; use tempfile::TempDir; #[test] @@ -199,8 +75,8 @@ mod tests { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let table_cache = crate::storage::TableCache::default(); + let view_cache = crate::storage::ViewCache::default(); let scala_functions = Default::default(); let table_functions = Default::default(); @@ -212,13 +88,15 @@ mod tests { &transaction, &scala_functions, &table_functions, - Arc::new(AtomicUsize::new(0)), ), &[], None, ); let stmt = crate::parser::parse_sql(sql).unwrap(); - let plan1 = binder.bind(&stmt[0]).unwrap(); + let stmt = stmt.into_iter().next().unwrap(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let plan1 = binder.bind(&stmt, &mut plan_arena).unwrap(); match plan1.operator { Operator::CreateTable(op) => { diff --git a/src/binder/create_view.rs b/src/binder/create_view.rs index 09c2bbe0..df48f104 100644 --- a/src/binder/create_view.rs +++ b/src/binder/create_view.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, lower_ident, Binder}; +use crate::binder::Binder; use crate::catalog::view::View; use crate::catalog::{ColumnCatalog, ColumnRef, TableName}; use crate::errors::DatabaseError; @@ -22,76 +22,86 @@ use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use itertools::Itertools; -use sqlparser::ast::{ObjectName, Query, ViewColumnDef}; use ulid::Ulid; impl> Binder<'_, '_, T, A> { pub(crate) fn bind_create_view( &mut self, - or_replace: &bool, - name: &ObjectName, - columns: &[ViewColumnDef], - query: &Query, + view_name: TableName, + or_replace: bool, + mut plan: LogicalPlan, + column_names: Vec, + output_aliases: Vec>, + arena: &mut crate::planner::PlanArena, ) -> Result { fn projection_exprs( view_name: &TableName, mapping_schema: &[ColumnRef], - column_names: impl Iterator, + arena: &mut crate::planner::PlanArena, + mut column_name: impl FnMut(usize, ColumnRef, &crate::planner::PlanArena) -> String, ) -> Vec { - column_names - .enumerate() - .map(|(i, column_name)| { - let mapping_column = &mapping_schema[i]; - let mut column = ColumnCatalog::new( - column_name, - mapping_column.nullable(), - mapping_column.desc().clone(), - ); - column.set_ref_table(view_name.clone(), Ulid::new(), true); + let mapping_schema_len = mapping_schema.len(); + let mut exprs = Vec::with_capacity(mapping_schema_len); + for (i, mapping_column) in mapping_schema.iter().copied().enumerate() { + let output_name = column_name(i, mapping_column, arena); + let (nullable, desc) = { + let mapping_column_catalog = arena.column(mapping_column); + ( + mapping_column_catalog.nullable(), + mapping_column_catalog.desc().clone(), + ) + }; + let mut column = ColumnCatalog::new(output_name, nullable, desc); + column.set_ref_table(view_name.clone(), Ulid::new(), true); + let output_column = arena.alloc_column(column); - ScalarExpression::Alias { - expr: Box::new(ScalarExpression::column_expr(mapping_column.clone(), i)), - alias: AliasType::Expr(Box::new(ScalarExpression::column_expr( - ColumnRef::from(column), - i, - ))), - } - }) - .collect_vec() + exprs.push(ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr(mapping_column, i)), + alias: AliasType::Expr(Box::new(ScalarExpression::column_expr( + output_column, + i, + ))), + }); + } + exprs } - let view_name: TableName = lower_case_name(name)?.into(); - let mut plan = self.bind_query(query)?; + let mapping_schema = plan.output_schema(arena); - let mapping_schema = plan.output_schema(); + if !column_names.is_empty() && column_names.len() > mapping_schema.len() { + return Err(DatabaseError::UnsupportedStmt(format!( + "view column count {} exceeds query output count {}", + column_names.len(), + mapping_schema.len() + ))); + } - let exprs: Vec = if columns.is_empty() { - projection_exprs( - &view_name, - mapping_schema, - mapping_schema - .iter() - .map(|column| column.name().to_string()), - ) + let exprs: Vec = if column_names.is_empty() { + projection_exprs(&view_name, mapping_schema, arena, |i, column, arena| { + output_aliases + .get(i) + .and_then(Clone::clone) + .unwrap_or_else(|| arena.column(column).name().to_string()) + }) } else { projection_exprs( &view_name, - mapping_schema, - columns - .iter() - .map(|column| lower_ident(&column.name).into_owned()), + &mapping_schema[..column_names.len()], + arena, + |i, _, _| column_names[i].clone(), ) }; - plan = self.bind_project(plan, exprs)?; + plan = self.bind_project(plan, exprs, arena)?; + let schema = plan.output_schema(arena).clone(); Ok(LogicalPlan::new( Operator::CreateView(CreateViewOperator { view: View { name: view_name, plan: Box::new(plan), + schema, }, - or_replace: *or_replace, + or_replace, }), Childrens::None, )) diff --git a/src/binder/delete.rs b/src/binder/delete.rs index 48993a90..51b732c2 100644 --- a/src/binder/delete.rs +++ b/src/binder/delete.rs @@ -12,50 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder}; -use crate::catalog::TableName; +use crate::binder::Binder; +use crate::catalog::{ColumnRef, TableName}; use crate::errors::DatabaseError; use crate::planner::operator::delete::DeleteOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use itertools::Itertools; -use sqlparser::ast::{Expr, TableFactor, TableWithJoins}; impl> Binder<'_, '_, T, A> { pub(crate) fn bind_delete( &mut self, - from: &TableWithJoins, - selection: &Option, + table_name: TableName, + primary_keys: Vec, + input: LogicalPlan, ) -> Result { - if let TableFactor::Table { name, .. } = &from.relation { - let table_name: TableName = lower_case_name(name)?.into(); - let table = self - .context - .table(table_name.clone())? - .ok_or(DatabaseError::TableNotFound)?; - let primary_keys = table - .primary_keys() - .iter() - .map(|(_, column)| column.clone()) - .collect_vec(); - self.with_pk(table_name.clone()); - let mut plan = self.bind_table_ref(from)?; - - if let Some(predicate) = selection { - plan = self.bind_where(plan, predicate)?; - } - - Ok(LogicalPlan::new( - Operator::Delete(DeleteOperator { - table_name, - primary_keys, - }), - Childrens::Only(Box::new(plan)), - )) - } else { - unreachable!("only table") - } + Ok(LogicalPlan::new( + Operator::Delete(DeleteOperator { + table_name, + primary_keys, + }), + Childrens::Only(Box::new(input)), + )) } } diff --git a/src/binder/describe.rs b/src/binder/describe.rs index aee151ff..440c97f7 100644 --- a/src/binder/describe.rs +++ b/src/binder/describe.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder}; +use crate::binder::Binder; use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::planner::operator::describe::DescribeOperator; @@ -20,15 +20,12 @@ use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use sqlparser::ast::ObjectName; impl> Binder<'_, '_, T, A> { pub(crate) fn bind_describe( &mut self, - name: &ObjectName, + table_name: TableName, ) -> Result { - let table_name: TableName = lower_case_name(name)?.into(); - Ok(LogicalPlan::new( Operator::Describe(DescribeOperator { table_name }), Childrens::None, diff --git a/src/binder/distinct.rs b/src/binder/distinct.rs index b52572c6..cc54bbbd 100644 --- a/src/binder/distinct.rs +++ b/src/binder/distinct.rs @@ -39,11 +39,12 @@ impl> Binder<'_, '_, T, A> } pub fn bind_distinct_output_exprs<'c>( - &self, + &mut self, select_list: &[ScalarExpression], exprs: impl IntoIterator, + arena: &mut crate::planner::PlanArena, ) -> Result<(), DatabaseError> { - let mut binder = DistinctOutputBinder::new(select_list); + let mut binder = DistinctOutputBinder::new(select_list, arena); for expr in exprs { binder.visit(expr)?; } @@ -51,11 +52,12 @@ impl> Binder<'_, '_, T, A> } pub fn bind_distinct_orderby_exprs( - &self, + &mut self, select_list: &[ScalarExpression], orderby: &mut [SortField], + arena: &mut crate::planner::PlanArena, ) -> Result<(), DatabaseError> { - let binder = DistinctOutputBinder::new(select_list); + let mut binder = DistinctOutputBinder::new(select_list, arena); for field in orderby { field.expr = binder.output_ref(&field.expr).ok_or_else(|| { @@ -70,29 +72,36 @@ impl> Binder<'_, '_, T, A> } } -struct DistinctOutputBinder<'a> { +struct DistinctOutputBinder<'a, 'p> { select_list: &'a [ScalarExpression], + arena: &'a mut crate::planner::PlanArena<'p>, } -impl<'a> DistinctOutputBinder<'a> { - fn new(select_list: &'a [ScalarExpression]) -> Self { - Self { select_list } +impl<'a, 'p> DistinctOutputBinder<'a, 'p> { + fn new( + select_list: &'a [ScalarExpression], + arena: &'a mut crate::planner::PlanArena<'p>, + ) -> Self { + Self { select_list, arena } } - fn output_ref(&self, expr: &ScalarExpression) -> Option { + fn output_ref(&mut self, expr: &ScalarExpression) -> Option { self.select_list .iter() .position(|candidate| { - candidate == expr || candidate.unpack_alias_ref() == expr.unpack_alias_ref() + candidate.eq_ignore_colref_pos(expr, self.arena) + || candidate + .unpack_alias_ref() + .eq_ignore_colref_pos(expr.unpack_alias_ref(), self.arena) }) .map(|position| { let output_expr = &self.select_list[position]; - ScalarExpression::column_expr(output_expr.output_column(), position) + ScalarExpression::column_expr(output_expr.output_column_ref(self.arena), position) }) } } -impl<'a> VisitorMut<'a> for DistinctOutputBinder<'_> { +impl<'a> VisitorMut<'a> for DistinctOutputBinder<'_, '_> { fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { if let ScalarExpression::Alias { expr: inner_expr, @@ -117,10 +126,11 @@ mod tests { use crate::errors::DatabaseError; use crate::expression::visitor_mut::VisitorMut; use crate::expression::{AliasType, ScalarExpression}; + use crate::planner::PlanArena; use crate::types::LogicalType; - fn test_column(name: &str, ty: LogicalType) -> ColumnRef { - ColumnRef::from(ColumnCatalog::new( + fn test_column(arena: &mut PlanArena, name: &str, ty: LogicalType) -> ColumnRef { + arena.alloc_column(ColumnCatalog::new( name.to_string(), true, ColumnDesc::new(ty, None, false, None).unwrap(), @@ -129,8 +139,10 @@ mod tests { #[test] fn test_distinct_output_binder_rewrites_output_slot() -> Result<(), DatabaseError> { - let left_column = test_column("c1", LogicalType::Integer); - let right_column = test_column("c2", LogicalType::Integer); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); + let left_column = test_column(&mut arena, "c1", LogicalType::Integer); + let right_column = test_column(&mut arena, "c2", LogicalType::Integer); let left_expr = ScalarExpression::column_expr(left_column, 0); let right_expr = ScalarExpression::column_expr(right_column, 1); @@ -141,54 +153,56 @@ mod tests { }; let select_list = [select_output.clone(), right_expr.clone()]; - let mut binder = DistinctOutputBinder::new(&select_list); - let mut order_by_alias = ScalarExpression::Alias { expr: Box::new(left_expr), alias: AliasType::Name("v".to_string()), }; - binder.visit(&mut order_by_alias)?; - assert_eq!( - order_by_alias, - ScalarExpression::Alias { - expr: Box::new(ScalarExpression::column_expr( - select_output.output_column(), - 0 - )), - alias: AliasType::Name("v".to_string()), - } - ); - let mut order_by_second = right_expr; - binder.visit(&mut order_by_second)?; - assert_eq!( - order_by_second, - ScalarExpression::column_expr(second_output.output_column(), 1) - ); + { + let mut binder = DistinctOutputBinder::new(&select_list, &mut arena); + binder.visit(&mut order_by_alias)?; + binder.visit(&mut order_by_second)?; + } + let expected_alias = ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr( + select_output.output_column_ref(&mut arena), + 0, + )), + alias: AliasType::Name("v".to_string()), + }; + assert!(order_by_alias.eq_ignore_colref_pos(&expected_alias, &arena)); + + let expected_second = + ScalarExpression::column_expr(second_output.output_column_ref(&mut arena), 1); + assert!(order_by_second.eq_ignore_colref_pos(&expected_second, &arena)); Ok(()) } #[test] fn test_distinct_output_binder_matches_alias_expr_reference() -> Result<(), DatabaseError> { - let column = test_column("c1", LogicalType::Integer); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); + let column = test_column(&mut arena, "c1", LogicalType::Integer); let expr = ScalarExpression::column_expr(column, 0); let select_output = ScalarExpression::Alias { expr: Box::new(expr.clone()), alias: AliasType::Name("v".to_string()), }; - let mut binder = DistinctOutputBinder::new(std::slice::from_ref(&select_output)); let mut target = ScalarExpression::Alias { expr: Box::new(ScalarExpression::Constant(1_i32.into())), alias: AliasType::Expr(Box::new(expr)), }; - binder.visit(&mut target)?; - assert_eq!( - target, - ScalarExpression::column_expr(select_output.output_column(), 0) - ); + { + let mut binder = + DistinctOutputBinder::new(std::slice::from_ref(&select_output), &mut arena); + binder.visit(&mut target)?; + } + let expected = + ScalarExpression::column_expr(select_output.output_column_ref(&mut arena), 0); + assert!(target.eq_ignore_colref_pos(&expected, &arena)); Ok(()) } diff --git a/src/binder/drop_index.rs b/src/binder/drop_index.rs index 51e2c219..ce138680 100644 --- a/src/binder/drop_index.rs +++ b/src/binder/drop_index.rs @@ -12,34 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{attach_span_if_absent, lower_name_part, Binder}; +use crate::binder::Binder; +use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::planner::operator::drop_index::DropIndexOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use sqlparser::ast::ObjectName; impl> Binder<'_, '_, T, A> { pub(crate) fn bind_drop_index( &mut self, - name: &ObjectName, - if_exists: &bool, + table_name: TableName, + index_name: String, + if_exists: bool, ) -> Result { - let table_name = name.0.first().ok_or_else(|| { - attach_span_if_absent(DatabaseError::invalid_table(name.to_string()), name) - })?; - let index_name = name.0.get(1).ok_or(DatabaseError::InvalidIndex)?; - - let table_name = lower_name_part(table_name)?.into(); - let index_name = lower_name_part(index_name)?; - Ok(LogicalPlan::new( Operator::DropIndex(DropIndexOperator { table_name, - index_name: index_name.into_owned(), - if_exists: *if_exists, + index_name, + if_exists, }), Childrens::None, )) diff --git a/src/binder/drop_table.rs b/src/binder/drop_table.rs index 24b58f62..749911f4 100644 --- a/src/binder/drop_table.rs +++ b/src/binder/drop_table.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder}; +use crate::binder::Binder; use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::planner::operator::drop_table::DropTableOperator; @@ -20,20 +20,17 @@ use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use sqlparser::ast::ObjectName; impl> Binder<'_, '_, T, A> { pub(crate) fn bind_drop_table( &mut self, - name: &ObjectName, - if_exists: &bool, + table_name: TableName, + if_exists: bool, ) -> Result { - let table_name: TableName = lower_case_name(name)?.into(); - Ok(LogicalPlan::new( Operator::DropTable(DropTableOperator { table_name, - if_exists: *if_exists, + if_exists, }), Childrens::None, )) diff --git a/src/binder/drop_view.rs b/src/binder/drop_view.rs index 7e74e995..7cb85d86 100644 --- a/src/binder/drop_view.rs +++ b/src/binder/drop_view.rs @@ -12,27 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder}; +use crate::binder::Binder; +use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::planner::operator::drop_view::DropViewOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use sqlparser::ast::ObjectName; impl> Binder<'_, '_, T, A> { pub(crate) fn bind_drop_view( &mut self, - name: &ObjectName, - if_exists: &bool, + view_name: TableName, + if_exists: bool, ) -> Result { - let view_name = lower_case_name(name)?.into(); - Ok(LogicalPlan::new( Operator::DropView(DropViewOperator { view_name, - if_exists: *if_exists, + if_exists, }), Childrens::None, )) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 56bc48df..9d055022 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -12,32 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::{ColumnCatalog, ColumnRef, TableName}; +use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression; use crate::expression::agg::AggKind; use itertools::Itertools; -use sqlparser::ast::{ - BinaryOperator, DataType, DuplicateTreatment, Expr, Function, FunctionArg, FunctionArgExpr, - FunctionArguments, Ident, Query, TypedString, UnaryOperator, Value, -}; -use std::borrow::Cow; -use std::collections::HashMap; -use std::slice; - -use super::{ - attach_span_from_sqlparser_span_if_absent, attach_span_if_absent, lower_ident, Binder, - BinderContext, QueryBindStep, SubQueryType, -}; + +use super::{Binder, BinderContext, QueryBindStep, SubQueryType}; use crate::expression::function::scala::{ArcScalarFunctionImpl, ScalarFunction}; -use crate::expression::function::table::{ArcTableFunctionImpl, TableFunction}; +use crate::expression::function::table::TableFunction; use crate::expression::function::FunctionSummary; use crate::expression::{AliasType, ScalarExpression}; use crate::planner::operator::mark_apply::MarkApplyQuantifier; use crate::planner::operator::scalar_subquery::ScalarSubqueryOperator; -use crate::planner::{LogicalPlan, SchemaOutput}; +use crate::planner::{LogicalPlan, PlanArena}; use crate::storage::Transaction; -use crate::types::tuple::SchemaRef; use crate::types::value::{DataValue, Utf8Type}; use crate::types::{CharLengthUnits, ColumnId, LogicalType}; @@ -50,375 +39,47 @@ macro_rules! try_default { } impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T, A> { - fn parse_like_escape_char(escape_char: &Option) -> Result, DatabaseError> { - match escape_char { - None => Ok(None), - Some(value) => match value { - Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => { - let mut chars = s.chars(); - let ch = chars.next().ok_or(DatabaseError::InvalidValue( - "escape character must not be empty".to_string(), - ))?; - if chars.next().is_some() { - return Err(DatabaseError::InvalidValue( - "escape character must be a single character".to_string(), - )); - } - Ok(Some(ch)) - } - _ => Err(DatabaseError::InvalidValue( - "escape character must be a quoted string".to_string(), - )), - }, - } - } - - fn find_column_in_schema<'b>( - schema_ref: &'b SchemaRef, + fn find_column_in_schema<'schema>( + schema_ref: impl IntoIterator, + arena: &PlanArena, column_name: &str, - ) -> Option<(usize, &'b ColumnRef)> { + ) -> Option<(usize, ColumnRef)> { schema_ref - .iter() + .into_iter() .enumerate() - .find(|(_, column)| column.name() == column_name) + .find(|(_, column)| arena.column(**column).name() == column_name) + .map(|(position, column)| (position, *column)) } fn find_column_in_scope( context: &BinderContext<'a, T>, - table_schema_buf: &mut HashMap>, + arena: &mut PlanArena, column_name: &str, ) -> Option { let mut position_offset = 0; for bound_source in &context.bind_table { - let schema_buf = table_schema_buf - .entry(bound_source.table_name.clone()) - .or_default(); - let schema_ref = bound_source.source.schema_ref(schema_buf); + let source = &bound_source.source; - if let Some((position, column)) = Self::find_column_in_schema(&schema_ref, column_name) + if let Some((position, column)) = + Self::find_column_in_schema(source.schema().iter(), arena, column_name) { return Some(ScalarExpression::column_expr( - column.clone(), + column, position_offset + position, )); } - position_offset += schema_ref.len(); + position_offset += source.schema_len(); } None } - - fn column_not_found_with_span(idents: &[Ident], column_name: &str) -> DatabaseError { - let err = DatabaseError::column_not_found(column_name.to_string()); - match idents.last() { - Some(ident) => attach_span_from_sqlparser_span_if_absent(err, ident.span), - None => err, - } - } - - pub(crate) fn bind_expr(&mut self, expr: &Expr) -> Result { - match expr { - Expr::Identifier(ident) => { - self.bind_column_ref_from_identifiers(slice::from_ref(ident), None) - } - Expr::CompoundIdentifier(idents) => self.bind_column_ref_from_identifiers(idents, None), - Expr::BinaryOp { left, right, op } => self.bind_binary_op_internal(left, right, op), - Expr::Value(v) => { - let value = if let Value::Placeholder(name) = &v.value { - self.args - .as_ref() - .iter() - .find_map(|(key, value)| (key == name).then(|| value.clone())) - .ok_or_else(|| { - attach_span_if_absent( - DatabaseError::parameter_not_found(name.to_string()), - v, - ) - })? - } else { - (&v.value) - .try_into() - .map_err(|err| attach_span_if_absent(err, v))? - }; - Ok(ScalarExpression::Constant(value)) - } - Expr::Function(func) => self.bind_function(func), - Expr::Nested(expr) => self.bind_expr(expr), - Expr::UnaryOp { expr, op } => self.bind_unary_op_internal(expr, op), - Expr::Like { - negated, - expr, - pattern, - escape_char, - any: _, - } => self.bind_like(*negated, expr, pattern, escape_char), - Expr::IsNull(expr) => self.bind_is_null(expr, false), - Expr::IsNotNull(expr) => self.bind_is_null(expr, true), - Expr::InList { - expr, - list, - negated, - } => self.bind_is_in(expr, list, *negated), - Expr::Cast { - expr, data_type, .. - } => self.bind_cast(expr, data_type), - Expr::TypedString(TypedString { - data_type, value, .. - }) => { - let logical_type = LogicalType::try_from(data_type.clone())?; - let raw = value.clone().into_string().ok_or_else(|| { - DatabaseError::InvalidValue("typed string literal must be a string".to_string()) - })?; - let value = DataValue::Utf8 { - value: raw, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - } - .cast(&logical_type) - .map_err(|err| attach_span_if_absent(err, expr))?; - - Ok(ScalarExpression::Constant(value)) - } - Expr::Between { - expr, - negated, - low, - high, - } => Ok(ScalarExpression::Between { - negated: *negated, - expr: Box::new(self.bind_expr(expr)?), - left_expr: Box::new(self.bind_expr(low)?), - right_expr: Box::new(self.bind_expr(high)?), - }), - Expr::Substring { - expr, - substring_for, - substring_from, - .. - } => { - let mut for_expr = None; - let mut from_expr = None; - - if let Some(expr) = substring_for { - for_expr = Some(Box::new(self.bind_expr(expr)?)) - } - if let Some(expr) = substring_from { - from_expr = Some(Box::new(self.bind_expr(expr)?)) - } - - Ok(ScalarExpression::SubString { - expr: Box::new(self.bind_expr(expr)?), - for_expr, - from_expr, - }) - } - Expr::Position { expr, r#in } => Ok(ScalarExpression::Position { - expr: Box::new(self.bind_expr(expr)?), - in_expr: Box::new(self.bind_expr(r#in)?), - }), - Expr::Trim { - expr, - trim_what, - trim_where, - .. - } => { - let mut trim_what_expr = None; - if let Some(trim_what) = trim_what { - trim_what_expr = Some(Box::new(self.bind_expr(trim_what)?)) - } - Ok(ScalarExpression::Trim { - expr: Box::new(self.bind_expr(expr)?), - trim_what_expr, - trim_where: trim_where.map(Into::into), - }) - } - Expr::Exists { subquery, negated } => { - let (sub_query, correlated) = self.bind_subquery(subquery)?; - let (_, marker_ref) = self - .bind_temp_table_alias(ScalarExpression::Constant(DataValue::Boolean(true)), 0); - self.context.sub_query(SubQueryType::ExistsSubQuery { - plan: sub_query, - correlated, - output_column: marker_ref.output_column(), - }); - if *negated { - Ok(ScalarExpression::Unary { - op: expression::UnaryOperator::Not, - expr: Box::new(marker_ref), - evaluator: None, - ty: LogicalType::Boolean, - }) - } else { - Ok(marker_ref) - } - } - Expr::Subquery(subquery) => { - let (sub_query, column, correlated) = - self.bind_subquery_with_output(None, subquery)?; - let sub_query = ScalarSubqueryOperator::build(sub_query); - let (expr, sub_query) = if !self.context.is_step(&QueryBindStep::Where) { - self.bind_temp_table(column, sub_query)? - } else { - (column, sub_query) - }; - self.context.sub_query(SubQueryType::SubQuery { - plan: sub_query, - correlated, - }); - Ok(expr) - } - Expr::InSubquery { - expr, - subquery, - negated, - } => self.bind_quantified_subquery( - MarkApplyQuantifier::Any, - *negated, - expr, - &BinaryOperator::Eq, - subquery, - ), - Expr::Tuple(exprs) => { - let mut bond_exprs = Vec::with_capacity(exprs.len()); - - for expr in exprs { - bond_exprs.push(self.bind_expr(expr)?); - } - Ok(ScalarExpression::Tuple(bond_exprs)) - } - Expr::Case { - operand, - conditions, - else_result, - .. - } => { - let fn_check_ty = |ty: &mut LogicalType, result_ty| { - if result_ty != LogicalType::SqlNull { - if ty == &LogicalType::SqlNull { - *ty = result_ty; - } else if ty != &result_ty { - return Err(DatabaseError::Incomparable(ty.clone(), result_ty)); - } - } - - Ok(()) - }; - let mut operand_expr = None; - let mut ty = LogicalType::SqlNull; - if let Some(expr) = operand { - operand_expr = Some(Box::new(self.bind_expr(expr)?)); - } - let mut expr_pairs = Vec::with_capacity(conditions.len()); - for when in conditions { - let result = self.bind_expr(&when.result)?; - let result_ty = result.return_type().into_owned(); - - fn_check_ty(&mut ty, result_ty)?; - expr_pairs.push((self.bind_expr(&when.condition)?, result)) - } - - let mut else_expr = None; - if let Some(expr) = else_result { - let temp_expr = Box::new(self.bind_expr(expr)?); - let else_ty = temp_expr.return_type().into_owned(); - - fn_check_ty(&mut ty, else_ty)?; - else_expr = Some(temp_expr); - } - - Ok(ScalarExpression::CaseWhen { - operand_expr, - expr_pairs, - else_expr, - ty, - }) - } - Expr::AnyOp { - left, - compare_op, - right, - .. - } => self.bind_quantified_op(MarkApplyQuantifier::Any, left, compare_op, right), - Expr::AllOp { - left, - compare_op, - right, - } => self.bind_quantified_op(MarkApplyQuantifier::All, left, compare_op, right), - expr => Err(DatabaseError::UnsupportedStmt(expr.to_string())), - } - } - - fn bind_quantified_op( - &mut self, - quantifier: MarkApplyQuantifier, - left: &Expr, - compare_op: &BinaryOperator, - right: &Expr, - ) -> Result { - let Expr::Subquery(subquery) = right else { - return Err(DatabaseError::UnsupportedStmt(format!( - "{quantifier:?} only supports subquery operands" - ))); - }; - - self.bind_quantified_subquery(quantifier, false, left, compare_op, subquery) - } - - fn bind_quantified_subquery( - &mut self, - quantifier: MarkApplyQuantifier, - negated: bool, - expr: &Expr, - compare_op: &BinaryOperator, - subquery: &Query, - ) -> Result { - let left_expr = self.bind_expr(expr)?; - let (sub_query, column, correlated) = - self.bind_subquery_with_output(Some(left_expr.return_type().as_ref()), subquery)?; - - if !self.context.is_step(&QueryBindStep::Where) { - return Err(DatabaseError::UnsupportedStmt( - "quantified subqueries can only appear in `WHERE`".to_string(), - )); - } - - let (alias_expr, sub_query) = self.bind_temp_table(column, sub_query)?; - let predicate = ScalarExpression::Binary { - op: (*compare_op).clone().try_into()?, - left_expr: Box::new(left_expr), - right_expr: Box::new(alias_expr), - evaluator: None, - ty: LogicalType::Boolean, - }; - let (_, marker_ref) = - self.bind_temp_table_alias(ScalarExpression::Constant(DataValue::Boolean(true)), 0); - self.context.sub_query(SubQueryType::QuantifiedSubQuery { - quantifier, - negated, - plan: sub_query, - correlated, - output_column: marker_ref.output_column(), - predicate, - }); - - if negated { - Ok(ScalarExpression::Unary { - op: expression::UnaryOperator::Not, - expr: Box::new(marker_ref), - evaluator: None, - ty: LogicalType::Boolean, - }) - } else { - Ok(marker_ref) - } - } - - fn bind_temp_table( + pub(crate) fn bind_temp_table( &mut self, expr: ScalarExpression, sub_query: LogicalPlan, + arena: &mut PlanArena, ) -> Result<(ScalarExpression, LogicalPlan), DatabaseError> { let (exprs, is_tuple) = match expr { ScalarExpression::Tuple(exprs) => (exprs, true), @@ -428,7 +89,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let mut alias_refs = Vec::with_capacity(exprs.len()); for (position, expr) in exprs.into_iter().enumerate() { - let (alias_expr, alias_ref) = self.bind_temp_table_alias(expr, position); + let (alias_expr, alias_ref) = self.bind_temp_table_alias(expr, position, arena); if !is_tuple { let alias_plan = Self::build_project_plan(sub_query, vec![alias_expr.clone()]); return Ok((alias_expr, alias_plan)); @@ -441,15 +102,18 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T Ok((ScalarExpression::Tuple(alias_refs), alias_plan)) } - fn bind_temp_table_alias( + pub(crate) fn bind_temp_table_alias( &mut self, expr: ScalarExpression, position: usize, + arena: &mut PlanArena, ) -> (ScalarExpression, ScalarExpression) { - let mut alias_column = ColumnCatalog::clone(&expr.output_column()); - alias_column.set_ref_table(self.context.temp_table(), ColumnId::new(), true); + let output_column = expr.output_column_ref(arena); + let mut alias_column = arena.clone_column(output_column); + alias_column.set_ref_table(arena.temp_table(), ColumnId::new(), true); - let alias_ref = ScalarExpression::column_expr(ColumnRef::from(alias_column), position); + let alias_column = arena.alloc_column(alias_column); + let alias_ref = ScalarExpression::column_expr(alias_column, position); ( ScalarExpression::Alias { expr: Box::new(expr), @@ -459,13 +123,55 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T ) } - fn bind_subquery_with_output( + pub(crate) fn bind_subquery_plan<'arena, F>( + &mut self, + arena: &mut PlanArena<'arena>, + build: F, + ) -> Result<(LogicalPlan, bool), DatabaseError> + where + F: FnOnce( + &mut Binder<'a, '_, T, A>, + &mut PlanArena<'arena>, + ) -> Result, + { + let BinderContext { + table_cache, + view_cache, + transaction, + scala_functions, + table_functions, + .. + } = &self.context; + let mut binder = Binder::new( + BinderContext::new( + table_cache, + view_cache, + *transaction, + scala_functions, + table_functions, + ), + self.args, + Some(&self.context), + ); + let sub_query = build(&mut binder, arena)?; + let correlated = binder.context.has_outer_refs(); + Ok((sub_query, correlated)) + } + + pub(crate) fn bind_subquery_plan_with_output<'arena, F>( &mut self, value_ty: Option<&LogicalType>, - subquery: &Query, - ) -> Result<(LogicalPlan, ScalarExpression, bool), DatabaseError> { - let (mut sub_query, correlated) = self.bind_subquery(subquery)?; - let sub_query_schema = sub_query.output_schema(); + arena: &mut PlanArena<'arena>, + build: F, + ) -> Result<(LogicalPlan, ScalarExpression, bool), DatabaseError> + where + F: FnOnce( + &mut Binder<'a, '_, T, A>, + &mut PlanArena<'arena>, + ) -> Result, + { + let (mut sub_query, correlated) = self.bind_subquery_plan(arena, build)?; + let sub_query_schema = sub_query.output_schema(arena); let fn_check = |len: usize| { if sub_query_schema.len() != len { @@ -483,141 +189,191 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let columns = sub_query_schema .iter() .enumerate() - .map(|(position, column)| ScalarExpression::column_expr(column.clone(), position)) + .map(|(position, column)| ScalarExpression::column_expr(*column, position)) .collect::>(); ScalarExpression::Tuple(columns) } else { fn_check(1)?; - ScalarExpression::column_expr(sub_query_schema[0].clone(), 0) + ScalarExpression::column_expr(sub_query_schema[0], 0) }; Ok((sub_query, expr, correlated)) } - fn bind_subquery(&mut self, subquery: &Query) -> Result<(LogicalPlan, bool), DatabaseError> { - let BinderContext { - table_cache, - view_cache, - transaction, - scala_functions, - table_functions, - temp_table_id, - .. - } = &self.context; - let mut binder = Binder::new( - BinderContext::new( - table_cache, - view_cache, - *transaction, - scala_functions, - table_functions, - temp_table_id.clone(), - ), - self.args, - Some(self), - ); - let sub_query = binder.bind_query(subquery)?; - let correlated = binder.context.has_outer_refs(); - Ok((sub_query, correlated)) + pub(crate) fn bind_scalar_subquery_plan<'arena, F>( + &mut self, + arena: &mut PlanArena<'arena>, + build: F, + ) -> Result + where + F: FnOnce( + &mut Binder<'a, '_, T, A>, + &mut PlanArena<'arena>, + ) -> Result, + { + let (sub_query, column, correlated) = + self.bind_subquery_plan_with_output(None, arena, build)?; + let sub_query = ScalarSubqueryOperator::build(sub_query); + let (expr, sub_query) = match self.context.step_now() { + QueryBindStep::Where => (column, sub_query), + QueryBindStep::Project => self.bind_temp_table(column, sub_query, arena)?, + _ => { + return Err(DatabaseError::UnsupportedStmt( + "scalar subqueries can only appear in `WHERE` or SELECT list".to_string(), + )) + } + }; + self.context.sub_query(SubQueryType::SubQuery { + plan: sub_query, + correlated, + }); + Ok(expr) } - pub fn bind_like( + pub(crate) fn bind_exists_subquery_plan<'arena, F>( &mut self, negated: bool, - expr: &Expr, - pattern: &Expr, - escape_char: &Option, - ) -> Result { - let left_expr = Box::new(self.bind_expr(expr)?); - let right_expr = Box::new(self.bind_expr(pattern)?); - let escape_char = Self::parse_like_escape_char(escape_char)?; - let op = if negated { - expression::BinaryOperator::NotLike(escape_char) + arena: &mut PlanArena<'arena>, + build: F, + ) -> Result + where + F: FnOnce( + &mut Binder<'a, '_, T, A>, + &mut PlanArena<'arena>, + ) -> Result, + { + if !self.context.is_step(&QueryBindStep::Where) { + return Err(DatabaseError::UnsupportedStmt( + "EXISTS subqueries can only appear in `WHERE`".to_string(), + )); + } + + let (sub_query, correlated) = self.bind_subquery_plan(arena, build)?; + let (_, marker_ref) = self.bind_temp_table_alias( + ScalarExpression::Constant(DataValue::Boolean(true)), + 0, + arena, + ); + let output_column = marker_ref.output_column_ref(arena); + self.context.sub_query(SubQueryType::ExistsSubQuery { + plan: sub_query, + correlated, + output_column, + }); + if negated { + Ok(ScalarExpression::Unary { + op: expression::UnaryOperator::Not, + expr: Box::new(marker_ref), + evaluator: None, + ty: LogicalType::Boolean, + }) } else { - expression::BinaryOperator::Like(escape_char) - }; - Ok(ScalarExpression::Binary { - op, - left_expr, - right_expr, + Ok(marker_ref) + } + } + + pub(crate) fn bind_quantified_subquery_plan<'arena, F>( + &mut self, + quantifier: MarkApplyQuantifier, + negated: bool, + left_expr: ScalarExpression, + compare_op: expression::BinaryOperator, + arena: &mut PlanArena<'arena>, + build: F, + ) -> Result + where + F: FnOnce( + &mut Binder<'a, '_, T, A>, + &mut PlanArena<'arena>, + ) -> Result, + { + let left_ty = left_expr.return_type(arena).into_owned(); + let (sub_query, column, correlated) = + self.bind_subquery_plan_with_output(Some(&left_ty), arena, build)?; + + if !self.context.is_step(&QueryBindStep::Where) { + return Err(DatabaseError::UnsupportedStmt( + "quantified subqueries can only appear in `WHERE`".to_string(), + )); + } + + let (alias_expr, sub_query) = self.bind_temp_table(column, sub_query, arena)?; + let predicate = ScalarExpression::Binary { + op: compare_op, + left_expr: Box::new(left_expr), + right_expr: Box::new(alias_expr), evaluator: None, ty: LogicalType::Boolean, - }) + }; + let (_, marker_ref) = self.bind_temp_table_alias( + ScalarExpression::Constant(DataValue::Boolean(true)), + 0, + arena, + ); + let output_column = marker_ref.output_column_ref(arena); + self.context.sub_query(SubQueryType::QuantifiedSubQuery { + quantifier, + negated, + plan: sub_query, + correlated, + output_column, + predicate, + }); + + if negated { + Ok(ScalarExpression::Unary { + op: expression::UnaryOperator::Not, + expr: Box::new(marker_ref), + evaluator: None, + ty: LogicalType::Boolean, + }) + } else { + Ok(marker_ref) + } } - pub fn bind_column_ref_from_identifiers( + pub(crate) fn bind_column_ref_by_name( &mut self, - idents: &[Ident], + table_name: Option<&str>, + column_name: &str, bind_table_name: Option<&str>, + arena: &mut PlanArena, ) -> Result { - let full_name = match idents { - [column] => (None, lower_ident(column)), - [table, column] => (Some(lower_ident(table)), lower_ident(column)), - _ => { - let invalid_name = idents - .iter() - .map(|ident| ident.value.clone()) - .join(".") - .to_string(); - let err = DatabaseError::invalid_column(invalid_name); - return Err(match idents.last() { - Some(ident) => attach_span_from_sqlparser_span_if_absent(err, ident.span), - None => err, - }); - } - }; - if full_name.0.is_none() { + if table_name.is_none() { if let Some((_, expr)) = self .context .expr_aliases .iter() - .find(|((table, column), _)| table.is_none() && column == full_name.1.as_ref()) + .find(|((table, column), _)| table.is_none() && column == column_name) { return Ok(ScalarExpression::Alias { expr: Box::new(expr.clone()), - alias: AliasType::Name(full_name.1.into_owned()), + alias: AliasType::Name(column_name.to_string()), }); } } if self.context.allow_default { - try_default!(&full_name.0, full_name.1); + try_default!(&table_name, column_name); } - if let Some(table) = full_name.0.as_deref().or(bind_table_name) { - let (schema_ref, position_offset) = match Self::resolve_source_columns_in_scope( - &self.context, - &mut self.table_schema_buf, - &table, - ) { - Ok(source) => source, - Err(err) => { - if let Some(parent) = self.parent { - self.context.mark_outer_ref(); - Self::resolve_source_columns_in_scope( - &parent.context, - &mut self.table_schema_buf, - &table, - ) - .map_err(|_| { - if let [table_ident, _] = idents { - attach_span_from_sqlparser_span_if_absent(err, table_ident.span) - } else { - err - } - })? - } else { - return Err(if let [table_ident, _] = idents { - attach_span_from_sqlparser_span_if_absent(err, table_ident.span) + if let Some(table) = table_name.or(bind_table_name) { + let (source, position_offset) = + match Self::resolve_source_columns_in_scope(&self.context, table) { + Ok(source) => source, + Err(err) => { + if let Some(parent) = self.parent { + self.context.mark_outer_ref(); + Self::resolve_source_columns_in_scope(parent, table).map_err(|_| err)? } else { - err - }); + return Err(err); + } } - } - }; - let (position, column) = Self::find_column_in_schema(&schema_ref, full_name.1.as_ref()) - .ok_or_else(|| Self::column_not_found_with_span(idents, full_name.1.as_ref()))?; + }; + let (position, column) = + Self::find_column_in_schema(source.schema().iter(), arena, column_name) + .ok_or_else(|| DatabaseError::column_not_found(column_name.to_string()))?; Ok(ScalarExpression::column_expr( - column.clone(), + column, position_offset + position, )) } else { @@ -626,53 +382,46 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T |context: &BinderContext<'a, T>| -> Result, DatabaseError> { Ok(context .using - .get(full_name.1.as_ref()) - .map(|using_column| using_column.visible_expr()) + .get(column_name) + .map(|using_column| using_column.visible_expr(arena)) .transpose()? .or_else(|| { - Self::find_column_in_scope( - context, - &mut self.table_schema_buf, - full_name.1.as_ref(), - ) + Self::find_column_in_scope(context, arena, column_name) })) }; let mut got_column = find_visible_column(&self.context)?; if got_column.is_none() { if let Some(parent) = self.parent { self.context.mark_outer_ref(); - got_column = find_visible_column(&parent.context)?; + got_column = find_visible_column(parent)?; } } match got_column { Some(column) => Ok(column), - None => Err(Self::column_not_found_with_span( - idents, - full_name.1.as_ref(), - )), + None => Err(DatabaseError::column_not_found(column_name.to_string())), } } } - fn bind_binary_op_internal( + pub(crate) fn bind_binary_op_expr( &mut self, - left: &Expr, - right: &Expr, - op: &BinaryOperator, + left_expr: ScalarExpression, + right_expr: ScalarExpression, + op: expression::BinaryOperator, + arena: &mut PlanArena, ) -> Result { - let left_expr = Box::new(self.bind_expr(left)?); - let right_expr = Box::new(self.bind_expr(right)?); - - let left_ty = left_expr.return_type(); - let right_ty = right_expr.return_type(); - let ty = match op { - BinaryOperator::Plus - | BinaryOperator::Minus - | BinaryOperator::Multiply - | BinaryOperator::Modulo => { + let left_expr = Box::new(left_expr); + let right_expr = Box::new(right_expr); + let left_ty = left_expr.return_type(arena); + let right_ty = right_expr.return_type(arena); + let ty = match &op { + expression::BinaryOperator::Plus + | expression::BinaryOperator::Minus + | expression::BinaryOperator::Multiply + | expression::BinaryOperator::Modulo => { LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned() } - BinaryOperator::Divide => { + expression::BinaryOperator::Divide => { if let LogicalType::Decimal(precision, scale) = LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned() { @@ -681,21 +430,24 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T LogicalType::Double } } - BinaryOperator::Gt - | BinaryOperator::Lt - | BinaryOperator::GtEq - | BinaryOperator::LtEq - | BinaryOperator::Eq - | BinaryOperator::NotEq - | BinaryOperator::And - | BinaryOperator::Or - | BinaryOperator::Xor => LogicalType::Boolean, - BinaryOperator::StringConcat => LogicalType::Varchar(None, CharLengthUnits::Characters), + expression::BinaryOperator::Gt + | expression::BinaryOperator::Lt + | expression::BinaryOperator::GtEq + | expression::BinaryOperator::LtEq + | expression::BinaryOperator::Eq + | expression::BinaryOperator::NotEq + | expression::BinaryOperator::Like(_) + | expression::BinaryOperator::NotLike(_) + | expression::BinaryOperator::And + | expression::BinaryOperator::Or => LogicalType::Boolean, + expression::BinaryOperator::StringConcat => { + LogicalType::Varchar(None, CharLengthUnits::Characters) + } op => return Err(DatabaseError::UnsupportedStmt(format!("{op}"))), }; Ok(ScalarExpression::Binary { - op: (op.clone()).try_into()?, + op, left_expr, right_expr, evaluator: None, @@ -703,59 +455,34 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T }) } - fn bind_unary_op_internal( + pub(crate) fn bind_unary_op_expr( &mut self, - expr: &Expr, - op: &UnaryOperator, + expr: ScalarExpression, + op: expression::UnaryOperator, + arena: &mut PlanArena, ) -> Result { - let expr = Box::new(self.bind_expr(expr)?); - let ty = if let UnaryOperator::Not = op { + let expr = Box::new(expr); + let ty = if let expression::UnaryOperator::Not = op { LogicalType::Boolean } else { - expr.return_type().into_owned() + expr.return_type(arena).into_owned() }; Ok(ScalarExpression::Unary { - op: (*op).try_into()?, + op, expr, evaluator: None, ty, }) } - fn bind_function(&mut self, func: &Function) -> Result { - let (func_args, is_distinct) = match &func.args { - FunctionArguments::List(args) => ( - args.args.as_slice(), - matches!(args.duplicate_treatment, Some(DuplicateTreatment::Distinct)), - ), - FunctionArguments::None => (&[][..], false), - FunctionArguments::Subquery(_) => { - return Err(DatabaseError::UnsupportedStmt( - "subquery function args are not supported".to_string(), - )) - } - }; - let mut args = Vec::with_capacity(func_args.len()); - - for arg in func_args { - let arg_expr = match arg { - FunctionArg::Named { arg, .. } => arg, - FunctionArg::ExprNamed { arg, .. } => arg, - FunctionArg::Unnamed(arg) => arg, - }; - match arg_expr { - FunctionArgExpr::Expr(expr) => args.push(self.bind_expr(expr)?), - FunctionArgExpr::Wildcard => args.push(Self::wildcard_expr()), - expr => { - return Err(DatabaseError::UnsupportedStmt(format!( - "function arg: {expr:#?}" - ))) - } - } - } - let function_name = func.name.to_string().to_lowercase(); - + pub(crate) fn bind_function_call( + &mut self, + function_name: String, + mut args: Vec, + is_distinct: bool, + arena: &mut PlanArena, + ) -> Result { match function_name.as_str() { "count" => { if args.len() != 1 { @@ -772,7 +499,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if args.len() != 1 { return Err(DatabaseError::MisMatch("number of sum() parameters", "1")); } - let ty = args[0].return_type().into_owned(); + let ty = args[0].return_type(arena).into_owned(); return Ok(ScalarExpression::AggCall { distinct: is_distinct, @@ -785,7 +512,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if args.len() != 1 { return Err(DatabaseError::MisMatch("number of min() parameters", "1")); } - let ty = args[0].return_type().into_owned(); + let ty = args[0].return_type(arena).into_owned(); return Ok(ScalarExpression::AggCall { distinct: is_distinct, @@ -798,7 +525,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if args.len() != 1 { return Err(DatabaseError::MisMatch("number of max() parameters", "1")); } - let ty = args[0].return_type().into_owned(); + let ty = args[0].return_type(arena).into_owned(); return Ok(ScalarExpression::AggCall { distinct: is_distinct, @@ -823,7 +550,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T if args.len() != 3 { return Err(DatabaseError::MisMatch("number of if() parameters", "3")); } - let ty = Self::return_type(&args[1], &args[2])?; + let ty = Self::return_type(&args[1], &args[2], arena)?; let right_expr = Box::new(args.pop().unwrap()); let left_expr = Box::new(args.pop().unwrap()); let condition = Box::new(args.pop().unwrap()); @@ -842,7 +569,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T "3", )); } - let ty = Self::return_type(&args[0], &args[1])?; + let ty = Self::return_type(&args[0], &args[1], arena)?; let right_expr = Box::new(args.pop().unwrap()); let left_expr = Box::new(args.pop().unwrap()); @@ -859,7 +586,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T "3", )); } - let ty = Self::return_type(&args[0], &args[1])?; + let ty = Self::return_type(&args[0], &args[1], arena)?; let right_expr = Box::new(args.pop().unwrap()); let left_expr = Box::new(args.pop().unwrap()); @@ -873,10 +600,10 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T let mut ty = LogicalType::SqlNull; if !args.is_empty() { - ty = args[0].return_type().into_owned(); + ty = args[0].return_type(arena).into_owned(); for arg in args.iter_mut() { - let temp_ty = arg.return_type().into_owned(); + let temp_ty = arg.return_type(arena).into_owned(); if temp_ty == LogicalType::SqlNull { continue; @@ -894,7 +621,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } let arg_types = args .iter() - .map(|arg| arg.return_type().into_owned()) + .map(|arg| arg.return_type(arena).into_owned()) .collect_vec(); let summary = FunctionSummary { name: function_name.into(), @@ -914,22 +641,20 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } return Ok(ScalarExpression::TableFunction(TableFunction { args, - inner: ArcTableFunctionImpl(function.clone()), + catalog: function.clone(), })); } - Err(attach_span_if_absent( - DatabaseError::function_not_found(summary.name.to_string()), - func, - )) + Err(DatabaseError::function_not_found(summary.name.to_string())) } - fn return_type( + pub(crate) fn return_type( expr_1: &ScalarExpression, expr_2: &ScalarExpression, + arena: &PlanArena, ) -> Result { - let temp_ty_1 = expr_1.return_type(); - let temp_ty_2 = expr_2.return_type(); + let temp_ty_1 = expr_1.return_type(arena); + let temp_ty_2 = expr_2.return_type(arena); match (temp_ty_1.as_ref(), temp_ty_2.as_ref()) { (LogicalType::SqlNull, LogicalType::SqlNull) => Ok(LogicalType::SqlNull), @@ -938,40 +663,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T } } - fn bind_is_null( - &mut self, - expr: &Expr, - negated: bool, - ) -> Result { - Ok(ScalarExpression::IsNull { - negated, - expr: Box::new(self.bind_expr(expr)?), - }) - } - - fn bind_is_in( - &mut self, - expr: &Expr, - list: &[Expr], - negated: bool, - ) -> Result { - let args = list.iter().map(|expr| self.bind_expr(expr)).try_collect()?; - - Ok(ScalarExpression::In { - negated, - expr: Box::new(self.bind_expr(expr)?), - args, - }) - } - - fn bind_cast(&mut self, expr: &Expr, ty: &DataType) -> Result { - ScalarExpression::type_cast( - self.bind_expr(expr)?, - Cow::Owned(LogicalType::try_from(ty.clone())?), - ) - } - - fn wildcard_expr() -> ScalarExpression { + pub(crate) fn wildcard_expr() -> ScalarExpression { ScalarExpression::Constant(DataValue::Utf8 { value: "*".to_string(), ty: Utf8Type::Variable(None), diff --git a/src/binder/insert.rs b/src/binder/insert.rs index bfc507e7..120a81ea 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -12,121 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{ - attach_span_from_sqlparser_span_if_absent, attach_span_if_absent, lower_case_name, lower_ident, - Binder, -}; +use crate::binder::Binder; use crate::catalog::TableName; use crate::errors::DatabaseError; -use crate::expression::simplify::ConstantCalculator; -use crate::expression::visitor_mut::VisitorMut; -use crate::expression::AliasType; -use crate::expression::ScalarExpression; use crate::planner::operator::insert::InsertOperator; use crate::planner::operator::values::ValuesOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use crate::types::value::DataValue; -use sqlparser::ast::{Expr, Ident, ObjectName, Query}; -use std::borrow::Cow; -use std::slice; -use std::sync::Arc; impl> Binder<'_, '_, T, A> { - pub(crate) fn bind_insert( + pub(crate) fn bind_insert_values( &mut self, - name: &ObjectName, - idents: &[Ident], - expr_rows: &Vec>, + table_name: TableName, + schema_ref: Schema, + rows: Vec>, is_overwrite: bool, is_mapping_by_name: bool, ) -> Result { - // FIXME: Make it better to detect the current BindStep - self.context.allow_default = true; - let table_name: TableName = lower_case_name(name)?.into(); - - let source = self - .context - .source_and_bind(table_name.clone(), None, None, false)? - .ok_or(DatabaseError::TableNotFound)?; - let mut _schema_ref = None; - let values_len = expr_rows[0].len(); - - if idents.is_empty() { - let schema_buf = self.table_schema_buf.entry(table_name.clone()).or_default(); - let temp_schema_ref = source.schema_ref(schema_buf); - if values_len > temp_schema_ref.len() { - return Err(DatabaseError::ValuesLenMismatch( - temp_schema_ref.len(), - values_len, - )); - } - _schema_ref = Some(temp_schema_ref); - } else { - let mut columns = Vec::with_capacity(idents.len()); - for ident in idents { - match self.bind_column_ref_from_identifiers( - slice::from_ref(ident), - Some(table_name.as_ref()), - )? { - ScalarExpression::ColumnRef { column, .. } => columns.push(column), - _ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())), - } - } - if values_len != columns.len() { - return Err(DatabaseError::ValuesLenMismatch(columns.len(), values_len)); - } - _schema_ref = Some(Arc::new(columns)); - } - let schema_ref = _schema_ref.ok_or(DatabaseError::ColumnsEmpty)?; - let mut rows = Vec::with_capacity(expr_rows.len()); - - for expr_row in expr_rows { - if expr_row.len() != values_len { - return Err(DatabaseError::ValuesLenMismatch(expr_row.len(), values_len)); - } - let mut row = Vec::with_capacity(expr_row.len()); - - for (i, expr) in expr_row.iter().enumerate() { - let mut expression = self.bind_expr(expr)?; - - ConstantCalculator.visit(&mut expression)?; - match expression { - ScalarExpression::Constant(mut value) => { - let ty = schema_ref[i].datatype(); - - value = value.cast(ty)?; - // Check if the value length is too long - value.check_len(ty)?; - if value.is_null() && !schema_ref[i].nullable() { - return Err(attach_span_if_absent( - DatabaseError::not_null_column(schema_ref[i].name().to_string()), - expr, - )); - } - - row.push(value); - } - ScalarExpression::Empty => { - let default_value = schema_ref[i] - .default_value()? - .ok_or(DatabaseError::DefaultNotExist)?; - if default_value.is_null() && !schema_ref[i].nullable() { - return Err(attach_span_if_absent( - DatabaseError::not_null_column(schema_ref[i].name().to_string()), - expr, - )); - } - row.push(default_value); - } - _ => return Err(DatabaseError::UnsupportedStmt(expr.to_string())), - } - } - rows.push(row); - } - self.context.allow_default = false; let values_plan = self.bind_values(rows, schema_ref); Ok(LogicalPlan::new( @@ -141,74 +46,10 @@ impl> Binder<'_, '_, T, A> pub(crate) fn bind_insert_query( &mut self, - name: &ObjectName, - idents: &[Ident], - query: &Query, + table_name: TableName, + input_plan: LogicalPlan, is_overwrite: bool, ) -> Result { - let table_name: TableName = lower_case_name(name)?.into(); - let table_schema = { - let source = self - .context - .source(&table_name)? - .ok_or(DatabaseError::TableNotFound)?; - let mut schema_buf = None; - source.schema_ref(&mut schema_buf) - }; - - let mut input_plan = self.bind_query(query)?; - let input_schema = input_plan.output_schema().clone(); - let input_len = input_schema.len(); - - let target_columns = if idents.is_empty() { - if input_len > table_schema.len() { - return Err(DatabaseError::ValuesLenMismatch( - table_schema.len(), - input_len, - )); - } - Cow::Borrowed(&table_schema[..input_len]) - } else { - let mut columns = Vec::with_capacity(idents.len()); - let source = self - .context - .source(&table_name)? - .ok_or(DatabaseError::TableNotFound)?; - let mut schema_buf = None; - for ident in idents { - let column_name = lower_ident(ident); - let column = source - .column(&column_name, &mut schema_buf) - .ok_or_else(|| { - attach_span_from_sqlparser_span_if_absent( - DatabaseError::column_not_found(column_name), - ident.span, - ) - })?; - columns.push(column); - } - if input_len != columns.len() { - return Err(DatabaseError::ValuesLenMismatch(columns.len(), input_len)); - } - Cow::Owned(columns) - }; - - let projection = input_schema - .iter() - .enumerate() - .zip(target_columns.iter()) - .map( - |((position, input_column), target_column)| ScalarExpression::Alias { - expr: Box::new(ScalarExpression::column_expr( - input_column.clone(), - position, - )), - alias: AliasType::Name(target_column.name().to_string()), - }, - ) - .collect::>(); - input_plan = self.bind_project(input_plan, projection)?; - Ok(LogicalPlan::new( Operator::Insert(InsertOperator { table_name, @@ -222,7 +63,7 @@ impl> Binder<'_, '_, T, A> pub(crate) fn bind_values( &mut self, rows: Vec>, - schema_ref: SchemaRef, + schema_ref: Schema, ) -> LogicalPlan { LogicalPlan::new( Operator::Values(ValuesOperator { rows, schema_ref }), diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 292f9a08..4323ae6d 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -12,9 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +macro_rules! with_query_bind_step { + ($binder:expr, $step:expr, $body:block) => {{ + let current_step = $binder.context.step_now(); + $binder.context.step($step); + let result = (|| -> Result<_, DatabaseError> { Ok($body) })(); + $binder.context.step(current_step); + result + }}; +} + +pub(crate) use with_query_bind_step; + pub mod aggregate; mod alter_table; mod analyze; +#[cfg(feature = "copy")] pub mod copy; mod create_index; mod create_table; @@ -28,32 +41,32 @@ mod drop_view; mod explain; pub mod expr; mod insert; +#[cfg(feature = "parser")] +mod parser; mod select; mod show_table; mod show_view; mod truncate; mod update; -use sqlparser::ast::{ - DescribeAlias, FromTable, Ident, ObjectName, ObjectNamePart, ObjectType, SetExpr, Spanned, - Statement, TableObject, -}; -use sqlparser::tokenizer::Span; -use std::borrow::Cow; +#[cfg(feature = "parser")] +pub use parser::{command_type, prepare, prepare_all, CommandType, Statement}; +#[cfg(feature = "orm")] +pub use select::{BindPlanFrom, BindPlanSelectList}; +#[cfg(feature = "orm")] +pub(crate) use select::{JoinConstraintInput, TableAliasInput}; use std::collections::{BTreeMap, HashMap}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; use crate::catalog::view::View; use crate::catalog::{ColumnRef, TableCatalog, TableName}; use crate::db::{ScalaFunctions, TableFunctions}; -use crate::errors::{DatabaseError, SqlErrorSpan}; +use crate::errors::DatabaseError; use crate::expression::ScalarExpression; use crate::planner::operator::join::JoinType; use crate::planner::operator::mark_apply::MarkApplyQuantifier; -use crate::planner::{LogicalPlan, SchemaOutput}; +use crate::planner::{LogicalPlan, PlanArena}; use crate::storage::{TableCache, Transaction, ViewCache}; -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -62,76 +75,11 @@ pub enum InputRefType { GroupBy, } -pub enum CommandType { - DQL, - DML, - DDL, -} - -fn annotate_bind_error(stmt: &Statement, err: DatabaseError) -> DatabaseError { - attach_span_if_absent(err, stmt) -} - -pub(crate) fn attach_span_from_sqlparser_span_if_absent( - err: DatabaseError, - span: Span, -) -> DatabaseError { - if err.sql_error_span().is_some() { - return err; - } - - match sqlparser_span_to_sql_error_span(span) { - Some(span) => err.with_span(span), - None => err, - } -} - -pub(crate) fn attach_span_if_absent( - err: DatabaseError, - node: &T, -) -> DatabaseError { - attach_span_from_sqlparser_span_if_absent(err, node.span()) -} - -pub(crate) fn sqlparser_span_to_sql_error_span(span: Span) -> Option { - if span == Span::empty() { - return None; - } - - let start = span.start.column as usize; - let mut end = span.end.column as usize; - if end <= start { - end = start.saturating_add(1); - } - - Some(SqlErrorSpan { - start, - end, - line: span.start.line as usize, - highlight: None, - }) -} - -pub fn command_type(stmt: &Statement) -> Result { - match stmt { - Statement::CreateTable(_) - | Statement::CreateIndex(_) - | Statement::CreateView(_) - | Statement::AlterTable(_) - | Statement::Drop { .. } => Ok(CommandType::DDL), - Statement::Query(_) - | Statement::Explain { .. } - | Statement::ExplainTable { .. } - | Statement::ShowTables { .. } - | Statement::ShowViews { .. } => Ok(CommandType::DQL), - Statement::Analyze(_) - | Statement::Truncate(_) - | Statement::Update(_) - | Statement::Delete(_) - | Statement::Insert(_) - | Statement::Copy { .. } => Ok(CommandType::DML), - stmt => Err(DatabaseError::UnsupportedStmt(stmt.to_string())), - } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) enum SetOperatorKind { + Union, + Except, + Intersect, } // Tips: only query now! @@ -173,7 +121,7 @@ pub enum SubQueryType { pub enum Source<'a> { Table(&'a TableCatalog), View(&'a View), - Schema(SchemaRef), + Schema(Schema), } #[derive(Debug, Clone)] @@ -233,21 +181,24 @@ impl UsingColumn { } fn left_expr(&self) -> ScalarExpression { - ScalarExpression::column_expr(self.left_column.clone(), self.left_position) + ScalarExpression::column_expr(self.left_column, self.left_position) } fn right_expr(&self) -> ScalarExpression { - ScalarExpression::column_expr(self.right_column.clone(), self.right_position) + ScalarExpression::column_expr(self.right_column, self.right_position) } - pub(crate) fn visible_expr(&self) -> Result { + pub(crate) fn visible_expr( + &self, + arena: &PlanArena, + ) -> Result { match self.join_type { JoinType::RightOuter => Ok(self.right_expr()), JoinType::Full => { let left_expr = self.left_expr(); let right_expr = self.right_expr(); - let left_ty = left_expr.return_type(); - let right_ty = right_expr.return_type(); + let left_ty = left_expr.return_type(arena); + let right_ty = right_expr.return_type(arena); let ty = LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned(); Ok(ScalarExpression::Coalesce { @@ -259,17 +210,16 @@ impl UsingColumn { } } - pub(crate) fn hides_column(&self, column: &ColumnRef) -> bool { + pub(crate) fn hides_column(&self, column: &ColumnRef, arena: &PlanArena) -> bool { let hidden_column = if self.join_type.is_right() { &self.left_column } else { &self.right_column }; - hidden_column.same_column(column) + arena.same_column(*hidden_column, *column) } } -#[derive(Clone)] pub struct BinderContext<'a, T: Transaction> { pub(crate) scala_functions: &'a ScalaFunctions, pub(crate) table_functions: &'a TableFunctions, @@ -292,37 +242,38 @@ pub struct BinderContext<'a, T: Transaction> { sub_queries: HashMap>, has_outer_refs: bool, - temp_table_id: Arc, pub(crate) allow_default: bool, } impl Source<'_> { - pub(crate) fn column( - &self, - name: &str, - schema_buf: &mut Option, - ) -> Option { + pub(crate) fn column(&self, name: &str, arena: &PlanArena) -> Option { match self { Source::Table(table) => table.get_column_by_name(name), - Source::View(view) => schema_buf - .get_or_insert_with(|| view.plan.output_schema_direct()) - .columns() - .find(|column| column.name() == name), - Source::Schema(schema_ref) => schema_ref.iter().find(|column| column.name() == name), + Source::View(view) => view + .schema + .iter() + .find(|column| arena.column(**column).name() == name) + .copied(), + Source::Schema(schema_ref) => schema_ref + .iter() + .find(|column| arena.column(**column).name() == name) + .copied(), } - .cloned() } - pub(crate) fn schema_ref(&self, schema_buf: &mut Option) -> SchemaRef { + pub(crate) fn schema(&self) -> &[ColumnRef] { match self { - Source::Table(table) => table.schema_ref().clone(), - Source::View(view) => { - match schema_buf.get_or_insert_with(|| view.plan.output_schema_direct()) { - SchemaOutput::Schema(schema) => Arc::new(schema.clone()), - SchemaOutput::SchemaRef(schema_ref) => schema_ref.clone(), - } - } - Source::Schema(schema_ref) => schema_ref.clone(), + Source::Table(table) => table.columns().as_slice(), + Source::View(view) => &view.schema, + Source::Schema(schema_ref) => schema_ref, + } + } + + pub(crate) fn schema_len(&self) -> usize { + match self { + Source::Table(table) => table.columns_len(), + Source::View(view) => view.schema.len(), + Source::Schema(schema_ref) => schema_ref.len(), } } } @@ -334,7 +285,6 @@ impl<'a, T: Transaction> BinderContext<'a, T> { transaction: &'a T, scala_functions: &'a ScalaFunctions, table_functions: &'a TableFunctions, - temp_table_id: Arc, ) -> Self { BinderContext { scala_functions, @@ -351,17 +301,46 @@ impl<'a, T: Transaction> BinderContext<'a, T> { bind_step: QueryBindStep::From, sub_queries: Default::default(), has_outer_refs: false, - temp_table_id, allow_default: false, } } - pub fn temp_table(&mut self) -> TableName { - format!( - "_temp_table_{}_", - self.temp_table_id.fetch_add(1, Ordering::SeqCst) + /// Creates a child context that starts with the current binding scope. + /// + /// This is used for nested query bodies that should be able to resolve the + /// same local sources and aliases while keeping their mutations isolated. + pub(crate) fn fork(&self) -> Self { + BinderContext { + scala_functions: self.scala_functions, + table_functions: self.table_functions, + table_cache: self.table_cache, + view_cache: self.view_cache, + transaction: self.transaction, + bind_table: self.bind_table.clone(), + expr_aliases: self.expr_aliases.clone(), + table_aliases: self.table_aliases.clone(), + group_by_exprs: self.group_by_exprs.clone(), + agg_calls: self.agg_calls.clone(), + using: self.using.clone(), + bind_step: self.bind_step, + sub_queries: Default::default(), + has_outer_refs: false, + allow_default: self.allow_default, + } + } + + /// Creates a child context with shared catalogs but without local bindings. + /// + /// This is used while binding an independent input, such as the right side + /// of a join, before merging its newly bound sources into the parent scope. + pub(crate) fn fork_empty(&self) -> Self { + BinderContext::new( + self.table_cache, + self.view_cache, + self.transaction, + self.scala_functions, + self.table_functions, ) - .into() } pub fn step(&mut self, bind_step: QueryBindStep) { @@ -550,9 +529,9 @@ impl<'a, T: Transaction> BinderContext<'a, T> { } let using_column = UsingColumn::new( join_type, - left_column.clone(), + *left_column, left_position, - right_column.clone(), + *right_column, right_position, ); self.using.insert(name, using_column); @@ -577,23 +556,21 @@ impl<'a, T: Transaction> BinderContext<'a, T> { } } -pub struct Binder<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> { - context: BinderContext<'a, T>, - table_schema_buf: HashMap>, - args: &'a A, +pub struct Binder<'a, 'parent, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> { + pub(crate) context: BinderContext<'a, T>, + pub(crate) args: &'a A, with_pk: Option, - pub(crate) parent: Option<&'b Binder<'a, 'b, T, A>>, + pub(crate) parent: Option<&'parent BinderContext<'a, T>>, } -impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> { +impl<'a, 'parent, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'parent, T, A> { pub fn new( context: BinderContext<'a, T>, args: &'a A, - parent: Option<&'b Binder<'a, 'b, T, A>>, + parent: Option<&'parent BinderContext<'a, T>>, ) -> Self { Binder { context, - table_schema_buf: Default::default(), args, with_pk: None, parent, @@ -604,6 +581,10 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' self.with_pk = Some(table_name); } + pub fn clear_with_pk(&mut self) { + self.with_pk = None; + } + pub fn is_scan_with_pk(&self, table_name: &TableName) -> bool { if let Some(with_pk_table) = self.with_pk.as_ref() { return with_pk_table == table_name; @@ -611,162 +592,7 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' false } - fn bind_inner(&mut self, stmt: &Statement) -> Result { - let plan = match stmt { - Statement::Query(query) => self.bind_query(query)?, - Statement::AlterTable(alter) => { - if alter.operations.len() != 1 { - return Err(DatabaseError::UnsupportedStmt( - "only a single ALTER TABLE operation is supported".to_string(), - )); - } - self.bind_alter_table(&alter.name, &alter.operations[0])? - } - Statement::CreateTable(create) => self.bind_create_table( - &create.name, - &create.columns, - &create.constraints, - create.if_not_exists, - )?, - Statement::Drop { - object_type, - names, - if_exists, - .. - } => { - if names.len() > 1 { - return Err(DatabaseError::UnsupportedStmt( - "only Drop a single `Table` or `View` is allowed".to_string(), - )); - } - match object_type { - ObjectType::Table => self.bind_drop_table(&names[0], if_exists)?, - ObjectType::View => self.bind_drop_view(&names[0], if_exists)?, - ObjectType::Index => self.bind_drop_index(&names[0], if_exists)?, - _ => { - return Err(DatabaseError::UnsupportedStmt( - "only `Table` and `View` are allowed to be Dropped".to_string(), - )) - } - } - } - Statement::Insert(insert) => { - let table_name = match &insert.table { - TableObject::TableName(table_name) => table_name, - TableObject::TableFunction(_) => { - return Err(DatabaseError::UnsupportedStmt( - "insert into table function is not supported".to_string(), - )) - } - }; - let source = insert.source.as_ref().ok_or_else(|| { - DatabaseError::UnsupportedStmt( - "insert without source is not supported".to_string(), - ) - })?; - match source.body.as_ref() { - SetExpr::Values(values) => self.bind_insert( - table_name, - &insert.columns, - &values.rows, - insert.overwrite, - false, - )?, - _ => self.bind_insert_query( - table_name, - &insert.columns, - source, - insert.overwrite, - )?, - } - } - Statement::Update(update) => { - let table = &update.table; - self.bind_update(table, &update.selection, &update.assignments)? - } - Statement::Delete(delete) => { - let from = match &delete.from { - FromTable::WithFromKeyword(from) | FromTable::WithoutKeyword(from) => from, - }; - let table = &from[0]; - - self.bind_delete(table, &delete.selection)? - } - Statement::Analyze(analyze) => { - let table_name = analyze.table_name.as_ref().ok_or_else(|| { - DatabaseError::UnsupportedStmt( - "ANALYZE without table is not supported".to_string(), - ) - })?; - self.bind_analyze(table_name)? - } - Statement::Truncate(truncate) => { - if truncate.table_names.len() != 1 { - return Err(DatabaseError::UnsupportedStmt( - "only truncate a single table is supported".to_string(), - )); - } - self.bind_truncate(&truncate.table_names[0].name)? - } - Statement::ShowTables { .. } => self.bind_show_tables()?, - Statement::ShowViews { .. } => self.bind_show_views()?, - Statement::Copy { - source, - to, - target, - options, - .. - } => self.bind_copy(source.clone(), *to, target.clone(), options)?, - Statement::Explain { statement, .. } => { - let plan = self.bind_inner(statement)?; - - self.bind_explain(plan)? - } - Statement::ExplainTable { - describe_alias: DescribeAlias::Describe | DescribeAlias::Desc, - table_name, - .. - } => self.bind_describe(table_name)?, - Statement::CreateIndex(create) => self.bind_create_index( - &create.table_name, - create.name.as_ref(), - &create.columns, - create.if_not_exists, - create.unique, - )?, - Statement::CreateView(create) => self.bind_create_view( - &create.or_replace, - &create.name, - &create.columns, - &create.query, - )?, - _ => return Err(DatabaseError::UnsupportedStmt(stmt.to_string())), - }; - Ok(plan) - } - - pub fn bind(&mut self, stmt: &Statement) -> Result { - self.bind_inner(stmt) - .map_err(|err| annotate_bind_error(stmt, err)) - } - - pub fn bind_set_expr(&mut self, set_expr: &SetExpr) -> Result { - match set_expr { - SetExpr::Select(select) => self.bind_select(select, None), - SetExpr::Query(query) => self.bind_query(query), - SetExpr::SetOperation { - op, - set_quantifier, - left, - right, - } => self.bind_set_operation(op, set_quantifier, left, right), - expr => Err(DatabaseError::UnsupportedStmt(format!( - "set expression: {expr:?}" - ))), - } - } - - fn extend(&mut self, context: BinderContext<'a, T>) { + pub(crate) fn extend(&mut self, context: BinderContext<'a, T>) { for bound_source in context.bind_table { self.context.add_bound_source( bound_source.table_name, @@ -784,33 +610,6 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' } } -fn lower_ident(ident: &Ident) -> Cow<'_, str> { - let value = &ident.value; - - if value.chars().any(char::is_uppercase) { - Cow::Owned(value.to_lowercase()) - } else { - Cow::Borrowed(value) - } -} - -fn lower_name_part(part: &ObjectNamePart) -> Result, DatabaseError> { - part.as_ident() - .map(lower_ident) - .ok_or_else(|| attach_span_if_absent(DatabaseError::invalid_table(part.to_string()), part)) -} - -/// Convert an object name into lower case -fn lower_case_name(name: &ObjectName) -> Result, DatabaseError> { - if name.0.len() == 1 { - return lower_name_part(&name.0[0]); - } - Err(attach_span_if_absent( - DatabaseError::invalid_table(name.to_string()), - name, - )) -} - pub(crate) fn is_valid_identifier(s: &str) -> bool { s.chars().all(|c| c.is_alphanumeric() || c == '_') && !s.chars().next().unwrap_or_default().is_numeric() @@ -822,27 +621,33 @@ pub mod test { use crate::binder::{is_valid_identifier, Binder, BinderContext}; use crate::catalog::{ColumnCatalog, ColumnDesc, TableCatalog}; use crate::errors::DatabaseError; - use crate::planner::LogicalPlan; + use crate::planner::{LogicalPlan, PlanArena, TableArenaCell}; use crate::storage::rocksdb::RocksStorage; - use crate::storage::{Storage, TableCache, Transaction, ViewCache}; + use crate::storage::{table_codec::TableCodec, Storage, TableCache, Transaction, ViewCache}; use crate::types::ColumnId; use crate::types::LogicalType::Integer; - use crate::utils::lru::SharedLruCache; - use std::hash::RandomState; use std::path::PathBuf; - use std::sync::atomic::AtomicUsize; - use std::sync::Arc; use tempfile::TempDir; pub(crate) struct TableState { pub(crate) table: TableCatalog, - pub(crate) table_cache: Arc, - pub(crate) view_cache: Arc, + pub(crate) table_cache: TableCache, + pub(crate) view_cache: ViewCache, + pub(crate) table_arena: TableArenaCell, pub(crate) storage: S, } impl TableState { pub(crate) fn plan>(&self, sql: T) -> Result { + let mut plan_arena = PlanArena::new(&self.table_arena); + self.plan_with_arena(sql, &mut plan_arena) + } + + pub(crate) fn plan_with_arena>( + &self, + sql: T, + plan_arena: &mut PlanArena, + ) -> Result { let scala_functions = Default::default(); let table_functions = Default::default(); let transaction = self.storage.transaction()?; @@ -853,14 +658,14 @@ pub mod test { &transaction, &scala_functions, &table_functions, - Arc::new(AtomicUsize::new(0)), ), &[], None, ); let stmt = crate::parser::parse_sql(sql)?; + let stmt = stmt.into_iter().next().unwrap(); - binder.bind(&stmt[0]) + binder.bind(&stmt, plan_arena) } pub(crate) fn column_id_by_name(&self, name: &str) -> &ColumnId { @@ -870,9 +675,10 @@ pub mod test { pub(crate) fn build_t1_table() -> Result, DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let storage = build_test_catalog(&table_cache, temp_dir.path())?; + let mut table_cache = crate::storage::TableCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_arena = TableArenaCell::default(); + let storage = build_test_catalog(&mut table_cache, temp_dir.path(), &table_arena)?; let table = { let transaction = storage.transaction()?; transaction @@ -885,19 +691,24 @@ pub mod test { table, table_cache, view_cache, + table_arena, storage, }) } pub(crate) fn build_test_catalog( - table_cache: &TableCache, + table_cache: &mut TableCache, path: impl Into + Send, + table_arena: &TableArenaCell, ) -> Result { let storage = RocksStorage::new(path)?; let mut transaction = storage.transaction()?; + let mut table_codec = TableCodec::default(); - let _ = transaction.create_table( - table_cache, + let mut plan_arena = PlanArena::new(table_arena); + if let Some(table) = transaction.create_table( + &mut table_codec, + &mut plan_arena, "t1".to_string().into(), vec![ ColumnCatalog::new( @@ -912,10 +723,15 @@ pub mod test { ), ], false, - )?; + )? { + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + } - let _ = transaction.create_table( - table_cache, + let mut plan_arena = PlanArena::new(table_arena); + if let Some(table) = transaction.create_table( + &mut table_codec, + &mut plan_arena, "t2".to_string().into(), vec![ ColumnCatalog::new( @@ -930,8 +746,10 @@ pub mod test { ), ], false, - )?; - + )? { + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + } transaction.commit()?; Ok(storage) diff --git a/src/binder/parser.rs b/src/binder/parser.rs new file mode 100644 index 00000000..0e158886 --- /dev/null +++ b/src/binder/parser.rs @@ -0,0 +1,2843 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::select::{ + BindPlanAggregated, BindPlanComplete, BindPlanDistinct, BindPlanFiltered, BindPlanFrom, + BindPlanHaving, BindPlanProjected, BindPlanSelectList, BindPlanStart, JoinConstraintInput, + TableAliasInput, +}; +use super::{is_valid_identifier, with_query_bind_step, Binder, QueryBindStep, SetOperatorKind}; +#[cfg(feature = "copy")] +use crate::binder::copy::{ExtSource, FileFormat}; +use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, TableName}; +use crate::db::{BindSource, DBTransaction, Database, DatabaseIter, TransactionIter}; +use crate::errors::{DatabaseError, SqlErrorSpan}; +use crate::expression; +use crate::expression::simplify::ConstantCalculator; +use crate::expression::visitor_mut::VisitorMut; +use crate::expression::{AliasType, ScalarExpression}; +use crate::parser::parse_sql; +use crate::planner::operator::alter_table::change_column::{DefaultChange, NotNullChange}; +use crate::planner::operator::join::{JoinCondition, JoinOperator as LJoinOperator, JoinType}; +use crate::planner::operator::mark_apply::MarkApplyQuantifier; +use crate::planner::operator::project::ProjectOperator; +use crate::planner::operator::sort::SortField; +use crate::planner::operator::Operator; +use crate::planner::{Childrens, LogicalPlan, PlanArena}; +use crate::storage::{Storage, Transaction}; +use crate::types::value::{DataValue, Utf8Type}; +use crate::types::{CharLengthUnits, ColumnId, LogicalType}; +use itertools::Itertools; +pub(super) use sqlparser::ast::{ + AlterColumnOperation, AlterTableOperation, Assignment, AssignmentTarget, BinaryOperator, + ColumnDef, ColumnOption, CreateView, DataType, DescribeAlias, Distinct, DuplicateTreatment, + Expr, FromTable, Function, FunctionArg, FunctionArgExpr, FunctionArguments, GroupByExpr, Ident, + IndexColumn, Join, JoinConstraint, JoinOperator, LimitClause, ObjectName, ObjectNamePart, + ObjectType, OrderByExpr, OrderByKind, Query, Select, SelectInto, SelectItem, + SelectItemQualifiedWildcardKind, SetExpr, SetOperator, SetQuantifier, Spanned, TableAlias, + TableConstraint, TableFactor, TableObject, TableWithJoins, TypedString, UnaryOperator, Value, +}; +#[cfg(feature = "copy")] +pub(super) use sqlparser::ast::{CopyOption, CopySource, CopyTarget}; +use sqlparser::tokenizer::Span; +use std::borrow::{Borrow, Cow}; +use std::cmp; +use std::slice; + +/// Parsed SQL statement type used by KiteSQL SQL frontend APIs. +pub type Statement = sqlparser::ast::Statement; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum CommandType { + DQL, + DML, + DDL, + Analyze, +} + +pub(crate) trait AttachSpanSource { + fn sql_error_span(self) -> Option; +} + +impl AttachSpanSource for &T { + fn sql_error_span(self) -> Option { + self.span().sql_error_span() + } +} + +impl AttachSpanSource for Span { + fn sql_error_span(self) -> Option { + if self == Span::empty() { + return None; + } + + let start = self.start.column as usize; + let mut end = self.end.column as usize; + if end <= start { + end = start.saturating_add(1); + } + + Some(SqlErrorSpan { + start, + end, + line: self.start.line as usize, + highlight: None, + }) + } +} + +pub(crate) fn attach_span_if_absent( + err: DatabaseError, + source: T, +) -> DatabaseError { + if err.sql_error_span().is_some() { + return err; + } + + match source.sql_error_span() { + Some(span) => err.with_span(span), + None => err, + } +} + +pub fn command_type(stmt: &Statement) -> Result { + match stmt { + Statement::CreateTable(_) + | Statement::CreateIndex(_) + | Statement::CreateView(_) + | Statement::AlterTable(_) + | Statement::Drop { .. } => Ok(CommandType::DDL), + Statement::Query(_) + | Statement::Explain { .. } + | Statement::ExplainTable { .. } + | Statement::ShowTables { .. } + | Statement::ShowViews { .. } => Ok(CommandType::DQL), + Statement::Analyze(_) => Ok(CommandType::Analyze), + Statement::Truncate(_) + | Statement::Update(_) + | Statement::Delete(_) + | Statement::Insert(_) => Ok(CommandType::DML), + #[cfg(feature = "copy")] + Statement::Copy { .. } => Ok(CommandType::DML), + stmt => Err(DatabaseError::UnsupportedStmt(stmt.to_string())), + } +} + +/// Parses a single SQL statement into a reusable [`Statement`]. +pub fn prepare>(sql: T) -> Result { + let mut stmts = prepare_all(sql)?; + stmts.pop().ok_or(DatabaseError::EmptyStatement) +} + +/// Parses one or more SQL statements into a vector of [`Statement`] values. +pub fn prepare_all>(sql: T) -> Result, DatabaseError> { + let stmts = parse_sql(sql)?; + if stmts.is_empty() { + return Err(DatabaseError::EmptyStatement); + } + Ok(stmts) +} + +fn statement_mutates_catalog_or_statistics(statement: &Statement) -> Result { + Ok(matches!( + command_type(statement)?, + CommandType::DDL | CommandType::Analyze + )) +} + +impl Database { + /// Executes a prepared [`Statement`] inside a database-owned transaction. + pub fn execute( + &self, + statement: St, + params: A, + ) -> Result, DatabaseError> + where + A: AsRef<[(&'static str, DataValue)]>, + St: Borrow, + { + if statement_mutates_catalog_or_statistics(statement.borrow())? { + return Err(DatabaseError::UnsupportedStmt( + "DDL and ANALYZE require `Database::ddl` or `Database::analyze`".to_string(), + )); + } + BindSource::execute(self, params, |binder, arena| { + binder.bind(statement.borrow(), arena) + }) + } + + pub fn ddl>(&mut self, sql: T) -> Result<(), DatabaseError> { + let sql = sql.as_ref(); + let statements = prepare_all(sql).map_err(|err| err.with_sql_context(sql))?; + + for statement in statements { + if !matches!(command_type(&statement)?, CommandType::DDL) { + return Err(DatabaseError::UnsupportedStmt( + "`Database::ddl` only accepts DDL statements".to_string(), + ) + .with_sql_context(sql)); + } + + self.execute_mut(sql, &[], |binder, arena| binder.bind(&statement, arena))?; + } + + Ok(()) + } + + /// Runs one or more SQL statements and returns an iterator for the final result set. + /// + /// Earlier statements in the same SQL string are executed eagerly. The last + /// statement is exposed as a streaming iterator. + /// + /// # Examples + /// + /// ```rust + /// use kite_sql::db::{DataBaseBuilder, ResultIter}; + /// + /// let mut database = DataBaseBuilder::path(".").build_in_memory().unwrap(); + /// database.ddl("create table t (id int primary key)").unwrap(); + /// let mut iter = database.run("select * from t").unwrap(); + /// iter.schema(|schema| assert_eq!(schema.len(), 1)); + /// iter.done().unwrap(); + /// ``` + pub fn run>(&self, sql: T) -> Result, DatabaseError> { + let sql = sql.as_ref(); + let statements = prepare_all(sql).map_err(|err| err.with_sql_context(sql))?; + let has_catalog_mutation = statements + .iter() + .try_fold(false, |has_mutation, stmt| { + Ok::<_, DatabaseError>( + has_mutation || statement_mutates_catalog_or_statistics(stmt)?, + ) + }) + .map_err(|err| err.with_sql_context(sql))?; + + if has_catalog_mutation { + return Err(DatabaseError::UnsupportedStmt( + "DDL and ANALYZE require `Database::ddl` or `Database::analyze`".to_string(), + ) + .with_sql_context(sql)); + } + + let transaction = Box::into_raw(Box::new( + self.storage + .transaction_with_isolation(self.transaction_isolation)?, + )); + let mut statements = statements.into_iter().peekable(); + + while let Some(statement) = statements.next() { + let (schema, plan_arena, executor) = + match self + .state + .execute(unsafe { &mut *transaction }, &[], |binder, arena| { + binder.bind(&statement, arena) + }) { + Ok(result) => result, + Err(err) => { + unsafe { drop(Box::from_raw(transaction)) }; + return Err(err.with_sql_context(sql)); + } + }; + + if statements.peek().is_some() { + if let Err(err) = + TransactionIter::new(schema, plan_arena, executor, transaction).done() + { + unsafe { drop(Box::from_raw(transaction)) }; + return Err(err.with_sql_context(sql)); + } + } else { + let inner = Box::into_raw(Box::new(TransactionIter::new( + schema, + plan_arena, + executor, + transaction, + ))); + return Ok(DatabaseIter { transaction, inner }); + } + } + + unsafe { drop(Box::from_raw(transaction)) }; + Err(DatabaseError::EmptyStatement.with_sql_context(sql)) + } +} + +impl<'txn, S: Storage> DBTransaction<'txn, S> { + /// Executes a prepared [`Statement`] inside the current transaction. + pub fn execute<'a, A, St>( + &'a mut self, + statement: St, + params: A, + ) -> Result>, DatabaseError> + where + A: AsRef<[(&'static str, DataValue)]>, + St: Borrow, + { + if matches!( + command_type(statement.borrow())?, + CommandType::DDL | CommandType::Analyze + ) { + return Err(DatabaseError::UnsupportedStmt( + "`DDL` and `ANALYZE` are not allowed to execute within a transaction".to_string(), + )); + } + BindSource::execute(self, params, |binder, arena| { + binder.bind(statement.borrow(), arena) + }) + } + + /// Runs SQL inside the current transaction and returns the final result iterator. + pub fn run<'a, T: AsRef>( + &'a mut self, + sql: T, + ) -> Result>, DatabaseError> { + let sql = sql.as_ref(); + let mut statements = prepare_all(sql).map_err(|err| err.with_sql_context(sql))?; + let last_statement = statements + .pop() + .ok_or_else(|| DatabaseError::EmptyStatement.with_sql_context(sql))?; + + for statement in statements { + self.execute(&statement, &[]) + .map_err(|err| err.with_sql_context(sql))? + .done() + .map_err(|err| err.with_sql_context(sql))?; + } + + self.execute(&last_statement, &[]) + .map_err(|err| err.with_sql_context(sql)) + } +} + +struct BindStatementStart<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + binder: &'s mut Binder<'a, 'b, T, A>, + arena: &'s mut PlanArena<'arena>, +} + +struct BindStatementComplete { + plan: LogicalPlan, +} + +struct UpdateExprTargetRemapper<'a, 'p> { + target_schema: &'a [ColumnRef], + arena: &'a PlanArena<'p>, +} + +impl VisitorMut<'_> for UpdateExprTargetRemapper<'_, '_> { + fn visit_column_ref( + &mut self, + column: &mut ColumnRef, + position: &mut usize, + ) -> Result<(), DatabaseError> { + let Some(target_position) = self + .target_schema + .iter() + .copied() + .position(|target_column| self.arena.same_column(target_column, *column)) + else { + return Err(DatabaseError::UnsupportedStmt( + "joined UPDATE SET expressions can only reference target table columns".to_string(), + )); + }; + *position = target_position; + Ok(()) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindStatementStart<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn statement(self, stmt: &Statement) -> Result { + let span = stmt.span(); + (|| { + let plan = match stmt { + Statement::Query(query) => self.binder.bind_query(query, self.arena)?, + Statement::AlterTable(alter) => self.alter_table(alter.clone())?, + Statement::CreateTable(create) => self.create_table(create.clone())?, + Statement::Drop { + object_type, + names, + if_exists, + .. + } => self.drop_object(*object_type, names.clone(), *if_exists)?, + Statement::Insert(insert) => self.insert(insert)?, + Statement::Update(update) => self.update(update)?, + Statement::Delete(delete) => self.delete(delete)?, + Statement::Analyze(analyze) => self.analyze(analyze.clone())?, + Statement::Truncate(truncate) => self.truncate(truncate.clone())?, + Statement::ShowTables { .. } => self.binder.bind_show_tables()?, + Statement::ShowViews { .. } => self.binder.bind_show_views()?, + #[cfg(feature = "copy")] + Statement::Copy { + source, + to, + target, + options, + .. + } => self.copy(source.clone(), *to, target.clone(), options.clone())?, + #[cfg(not(feature = "copy"))] + Statement::Copy { .. } => { + return Err(DatabaseError::UnsupportedStmt( + "COPY requires the `copy` feature".to_string(), + )) + } + Statement::Explain { statement, .. } => self.explain(statement)?, + Statement::ExplainTable { + describe_alias: DescribeAlias::Describe | DescribeAlias::Desc, + table_name, + .. + } => self + .binder + .bind_describe(sql_table_name(table_name.clone())?)?, + Statement::CreateIndex(create) => self.create_index(create.clone())?, + Statement::CreateView(create) => self.create_view(create.clone())?, + _ => return Err(DatabaseError::UnsupportedStmt(stmt.to_string())), + }; + + Ok(BindStatementComplete { plan }) + })() + .map_err(|err| attach_span_if_absent(err, span)) + } + + fn alter_table(self, alter: sqlparser::ast::AlterTable) -> Result { + if alter.operations.len() != 1 { + return Err(DatabaseError::UnsupportedStmt( + "only a single ALTER TABLE operation is supported".to_string(), + )); + } + let operation = alter.operations.into_iter().next().unwrap(); + self.alter_table_operation(sql_table_name(alter.name)?, operation) + } + + fn alter_table_operation( + mut self, + table_name: TableName, + operation: AlterTableOperation, + ) -> Result { + self.binder + .context + .table(table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + + match operation { + AlterTableOperation::AddColumn { + column_keyword: _, + if_not_exists, + column_def, + .. + } => { + let column_span = column_def.name.span; + let column = self.bind_column(column_def, None)?; + + if !is_valid_identifier(column.name()) { + return Err(attach_span_if_absent( + DatabaseError::invalid_column("illegal column naming".to_string()), + column_span, + )); + } + + self.binder + .bind_add_column(table_name, column, if_not_exists) + } + AlterTableOperation::DropColumn { + column_names, + if_exists, + .. + } => { + if column_names.len() != 1 { + return Err(DatabaseError::UnsupportedStmt( + "only dropping a single column is supported".to_string(), + )); + } + let column_name = column_names[0].value.clone(); + + self.binder + .bind_drop_column(table_name, column_name, if_exists) + } + AlterTableOperation::RenameColumn { + old_column_name, + new_column_name, + } => { + let old_column_name = lower_ident(&old_column_name); + let new_column_name = lower_ident(&new_column_name).into_owned(); + let data_type = { + let table = self + .binder + .context + .table(table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + table + .get_column_by_name(old_column_name.as_ref()) + .map(|column| self.arena.column(column).datatype().clone()) + .ok_or_else(|| { + DatabaseError::column_not_found(old_column_name.to_string()) + })? + }; + + if !is_valid_identifier(&new_column_name) { + return Err(DatabaseError::invalid_column( + "illegal column naming".to_string(), + )); + } + + self.binder.bind_change_column( + table_name, + old_column_name.into_owned(), + new_column_name, + data_type, + DefaultChange::NoChange, + NotNullChange::NoChange, + ) + } + AlterTableOperation::AlterColumn { column_name, op } => { + let old_column_name = lower_ident(&column_name); + let old_data_type = { + let table = self + .binder + .context + .table(table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + table + .get_column_by_name(old_column_name.as_ref()) + .map(|column| self.arena.column(column).datatype().clone()) + .ok_or_else(|| { + DatabaseError::column_not_found(old_column_name.to_string()) + })? + }; + + let (data_type, default_change, not_null_change) = match op { + AlterColumnOperation::SetDataType { + data_type, using, .. + } => { + if using.is_some() { + return Err(DatabaseError::UnsupportedStmt( + "ALTER COLUMN TYPE USING is not supported".to_string(), + )); + } + ( + LogicalType::try_from(data_type)?, + DefaultChange::NoChange, + NotNullChange::NoChange, + ) + } + AlterColumnOperation::SetDefault { value } => ( + old_data_type.clone(), + DefaultChange::Set(self.bind_alter_default_expr(value, &old_data_type)?), + NotNullChange::NoChange, + ), + AlterColumnOperation::DropDefault => ( + old_data_type.clone(), + DefaultChange::Drop, + NotNullChange::NoChange, + ), + AlterColumnOperation::SetNotNull => ( + old_data_type.clone(), + DefaultChange::NoChange, + NotNullChange::Set, + ), + AlterColumnOperation::DropNotNull => ( + old_data_type.clone(), + DefaultChange::NoChange, + NotNullChange::Drop, + ), + _ => { + return Err(DatabaseError::UnsupportedStmt(format!( + "unsupported alter column operation: {op:?}" + ))) + } + }; + + self.binder.bind_change_column( + table_name, + old_column_name.to_string(), + old_column_name.into_owned(), + data_type, + default_change, + not_null_change, + ) + } + AlterTableOperation::ModifyColumn { + col_name, + data_type, + options, + column_position, + } => { + if column_position.is_some() { + return Err(DatabaseError::UnsupportedStmt( + "MODIFY COLUMN does not currently support column positions".to_string(), + )); + } + let old_column_name = lower_ident(&col_name); + { + let table = self + .binder + .context + .table(table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + let _ = table + .get_column_by_name(old_column_name.as_ref()) + .ok_or_else(|| { + DatabaseError::column_not_found(old_column_name.to_string()) + })?; + } + let old_column_name = old_column_name.into_owned(); + let data_type = LogicalType::try_from(data_type)?; + let (default_change, not_null_change) = + self.bind_change_column_options(options, &data_type)?; + + self.binder.bind_change_column( + table_name, + old_column_name.clone(), + old_column_name, + data_type, + default_change, + not_null_change, + ) + } + AlterTableOperation::ChangeColumn { + old_name, + new_name, + data_type, + options, + column_position, + } => { + if column_position.is_some() { + return Err(DatabaseError::UnsupportedStmt( + "CHANGE COLUMN does not currently support column positions".to_string(), + )); + } + let old_column_name = lower_ident(&old_name); + let new_column_name = lower_ident(&new_name).into_owned(); + { + let table = self + .binder + .context + .table(table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + let _ = table + .get_column_by_name(old_column_name.as_ref()) + .ok_or_else(|| { + DatabaseError::column_not_found(old_column_name.to_string()) + })?; + } + + if !is_valid_identifier(&new_column_name) { + return Err(DatabaseError::invalid_column( + "illegal column naming".to_string(), + )); + } + let data_type = LogicalType::try_from(data_type)?; + let (default_change, not_null_change) = + self.bind_change_column_options(options, &data_type)?; + + self.binder.bind_change_column( + table_name, + old_column_name.into_owned(), + new_column_name, + data_type, + default_change, + not_null_change, + ) + } + op => Err(DatabaseError::UnsupportedStmt(format!( + "AlertOperation: {op:?}" + ))), + } + } + + fn bind_alter_default_expr( + &mut self, + expr: Expr, + ty: &LogicalType, + ) -> Result { + let mut expr = self.binder.bind_expr(&expr, self.arena)?; + + if expr.any_referenced_column(self.arena, |_, _| true) { + return Err(DatabaseError::UnsupportedStmt( + "column is not allowed to exist in default".to_string(), + )); + } + expr = ScalarExpression::type_cast(expr, Cow::Borrowed(ty), self.arena)?; + + Ok(expr) + } + + fn bind_change_column_options( + &mut self, + options: Vec, + data_type: &LogicalType, + ) -> Result<(DefaultChange, NotNullChange), DatabaseError> { + let mut default_change = DefaultChange::NoChange; + let mut not_null_change = NotNullChange::NoChange; + + for option in options { + match option { + ColumnOption::Null => not_null_change = NotNullChange::Drop, + ColumnOption::NotNull => not_null_change = NotNullChange::Set, + ColumnOption::Default(expr) => { + default_change = + DefaultChange::Set(self.bind_alter_default_expr(expr, data_type)?); + } + option => { + return Err(DatabaseError::UnsupportedStmt(format!( + "CHANGE/MODIFY COLUMN does not currently support this option: {option:?}" + ))) + } + } + } + + Ok((default_change, not_null_change)) + } + + fn create_table( + mut self, + create: sqlparser::ast::CreateTable, + ) -> Result { + let table_name = sql_table_name(create.name.clone())?; + + if !is_valid_identifier(&table_name) { + return Err(attach_span_if_absent( + DatabaseError::invalid_table("illegal table naming".to_string()), + &create.name, + )); + } + for col in create.columns.iter() { + let col_name = &col.name.value; + if !is_valid_identifier(col_name) { + return Err(attach_span_if_absent( + DatabaseError::invalid_column("illegal column naming".to_string()), + col, + )); + } + } + + let mut columns = Vec::with_capacity(create.columns.len()); + for (i, column) in create.columns.into_iter().enumerate() { + columns.push(self.bind_column(column, Some(i))?); + } + for constraint in create.constraints { + match constraint { + TableConstraint::PrimaryKey(primary) => { + self.bind_constraint(&mut columns, primary.columns, |i, desc| { + desc.set_primary(Some(i)) + })?; + } + TableConstraint::Unique(unique) => { + self.bind_constraint(&mut columns, unique.columns, |_, desc| { + desc.set_unique() + })?; + } + constraint => { + return Err(DatabaseError::UnsupportedStmt(format!( + "`CreateTable` does not currently support this constraint: {constraint:?}" + ))) + } + } + } + + self.binder + .bind_create_table(table_name, columns, create.if_not_exists) + } + + fn bind_column( + &mut self, + column_def: ColumnDef, + column_index: Option, + ) -> Result { + let column_name = lower_ident(&column_def.name).into_owned(); + let mut column_desc = ColumnDesc::new( + LogicalType::try_from(column_def.data_type)?, + None, + false, + None, + )?; + let mut nullable = true; + + for option_def in column_def.options { + match option_def.option { + ColumnOption::Null => nullable = true, + ColumnOption::NotNull => nullable = false, + ColumnOption::PrimaryKey(_) => { + column_desc.set_primary(column_index); + nullable = false; + break; + } + ColumnOption::Unique(_) => column_desc.set_unique(), + ColumnOption::Default(expr) => { + let mut expr = self.binder.bind_expr(&expr, self.arena)?; + + if expr.any_referenced_column(self.arena, |_, _| true) { + return Err(DatabaseError::UnsupportedStmt( + "column is not allowed to exist in `default`".to_string(), + )); + } + expr = ScalarExpression::type_cast( + expr, + Cow::Borrowed(&column_desc.column_datatype), + self.arena, + )?; + column_desc.default = Some(expr); + } + option => { + return Err(DatabaseError::UnsupportedStmt(format!( + "`Column` does not currently support this option: {option:?}" + ))) + } + } + } + + Ok(ColumnCatalog::new(column_name, nullable, column_desc)) + } + + fn bind_constraint( + &mut self, + table_columns: &mut [ColumnCatalog], + exprs: Vec, + fn_constraint: F, + ) -> Result<(), DatabaseError> { + for (i, index_column) in exprs.into_iter().enumerate() { + let Expr::Identifier(ident) = index_column.column.expr else { + return Err(DatabaseError::UnsupportedStmt( + "only identifier columns are supported in `PRIMARY KEY/UNIQUE`".to_string(), + )); + }; + let column_name = lower_ident(&ident); + + if let Some(column) = table_columns + .iter_mut() + .find(|column| column.name() == column_name.as_ref()) + { + fn_constraint(i, column.desc_mut()) + } + } + Ok(()) + } + + fn create_index( + self, + create: sqlparser::ast::CreateIndex, + ) -> Result { + let table_name = sql_table_name(create.table_name)?; + let index_name = create + .name + .ok_or(DatabaseError::InvalidIndex) + .and_then(sql_object_name)?; + let input = self + .binder + .bind_create_index_source(table_name.clone(), self.arena)?; + let mut columns = Vec::with_capacity(create.columns.len()); + + for index_column in create.columns { + match self + .binder + .bind_expr(&index_column.column.expr, self.arena)? + { + ScalarExpression::ColumnRef { column, .. } => columns.push(column), + expr => { + return Err(DatabaseError::UnsupportedStmt(format!( + "'CREATE INDEX' by {expr}" + ))) + } + } + } + + self.binder.bind_create_index( + table_name, + index_name, + columns, + create.if_not_exists, + create.unique, + input, + ) + } + + fn create_view(self, create: CreateView) -> Result { + let CreateView { + or_replace, + name, + columns, + query, + .. + } = create; + let output_aliases = query_output_aliases(&query); + let view_name = sql_table_name(name)?; + let column_names = columns + .into_iter() + .map(|column| lower_ident(&column.name).into_owned()) + .collect(); + let plan = self.binder.bind_query(query.as_ref(), self.arena)?; + + self.binder.bind_create_view( + view_name, + or_replace, + plan, + column_names, + output_aliases, + self.arena, + ) + } + + fn drop_object( + self, + object_type: ObjectType, + mut names: Vec, + if_exists: bool, + ) -> Result { + if names.len() > 1 { + return Err(DatabaseError::UnsupportedStmt( + "only Drop a single `Table` or `View` is allowed".to_string(), + )); + } + + match object_type { + ObjectType::Table => self + .binder + .bind_drop_table(sql_table_name(names.remove(0))?, if_exists), + ObjectType::View => self + .binder + .bind_drop_view(sql_table_name(names.remove(0))?, if_exists), + ObjectType::Index => { + let (table_name, index_name) = sql_index_name(names.remove(0))?; + self.binder + .bind_drop_index(table_name, index_name, if_exists) + } + _ => Err(DatabaseError::UnsupportedStmt( + "only `Table` and `View` are allowed to be Dropped".to_string(), + )), + } + } + + #[cfg(feature = "copy")] + fn copy( + self, + source: CopySource, + to: bool, + target: CopyTarget, + options: Vec, + ) -> Result { + let ext_source = copy_ext_source(target, options)?; + + match source { + CopySource::Table { table_name, .. } => { + self.binder + .bind_copy_table(sql_table_name(table_name)?, to, ext_source, self.arena) + } + CopySource::Query(query) => { + if !to { + return Err(DatabaseError::UnsupportedStmt( + "'COPY FROM query'".to_string(), + )); + } + let input_plan = self.binder.bind_query(query.as_ref(), self.arena)?; + self.binder.bind_copy_to_file(ext_source, input_plan) + } + } + } + + fn insert(mut self, insert: &sqlparser::ast::Insert) -> Result { + let sqlparser::ast::Insert { + table, + columns, + overwrite, + source, + .. + } = insert; + let table_name = match table { + TableObject::TableName(table_name) => table_name.clone(), + TableObject::TableFunction(_) => { + return Err(DatabaseError::UnsupportedStmt( + "insert into table function is not supported".to_string(), + )) + } + }; + let table_name = sql_table_name(table_name)?; + let source = source.as_ref().ok_or_else(|| { + DatabaseError::UnsupportedStmt("insert without source is not supported".to_string()) + })?; + if let SetExpr::Values(values) = source.body.as_ref() { + self.insert_values(table_name, columns, &values.rows, *overwrite, false) + } else { + self.insert_query(table_name, columns, source, *overwrite) + } + } + + fn insert_values( + &mut self, + table_name: TableName, + idents: &[Ident], + expr_rows: &[Vec], + is_overwrite: bool, + is_mapping_by_name: bool, + ) -> Result { + self.binder.context.allow_default = true; + let source = self + .binder + .context + .source_and_bind(table_name.clone(), None, None, false)? + .ok_or(DatabaseError::TableNotFound)?; + let values_len = expr_rows[0].len(); + + let schema_ref = if idents.is_empty() { + if values_len > source.schema_len() { + return Err(DatabaseError::ValuesLenMismatch( + source.schema_len(), + values_len, + )); + } + source.schema().to_vec() + } else { + let mut columns = Vec::with_capacity(idents.len()); + for ident in idents { + match self.binder.bind_column_ref_from_identifiers( + slice::from_ref(ident), + Some(table_name.as_ref()), + self.arena, + )? { + ScalarExpression::ColumnRef { column, .. } => columns.push(column), + _ => return Err(DatabaseError::UnsupportedStmt(ident.to_string())), + } + } + if values_len != columns.len() { + return Err(DatabaseError::ValuesLenMismatch(columns.len(), values_len)); + } + columns + }; + let mut rows = Vec::with_capacity(expr_rows.len()); + + for expr_row in expr_rows { + if expr_row.len() != values_len { + return Err(DatabaseError::ValuesLenMismatch(expr_row.len(), values_len)); + } + let mut row = Vec::with_capacity(expr_row.len()); + + for (i, expr) in expr_row.iter().enumerate() { + let span = expr.span(); + let mut expression = self.binder.bind_expr(expr, self.arena)?; + + ConstantCalculator::new(self.arena).visit(&mut expression)?; + match expression { + ScalarExpression::Constant(mut value) => { + let column = self.arena.column(schema_ref[i]); + let ty = column.datatype(); + + value = value.cast(ty)?; + value.check_len(ty)?; + if value.is_null() && !column.nullable() { + return Err(attach_span_if_absent( + DatabaseError::not_null_column(column.name().to_string()), + span, + )); + } + + row.push(value); + } + ScalarExpression::Empty => { + let column = self.arena.column(schema_ref[i]); + let default_value = column + .default_value()? + .ok_or(DatabaseError::DefaultNotExist)?; + if default_value.is_null() && !column.nullable() { + return Err(attach_span_if_absent( + DatabaseError::not_null_column(column.name().to_string()), + span, + )); + } + row.push(default_value); + } + _ => { + return Err(attach_span_if_absent( + DatabaseError::UnsupportedStmt( + "INSERT values must be constants or DEFAULT".to_string(), + ), + span, + )) + } + } + } + rows.push(row); + } + self.binder.context.allow_default = false; + + self.binder.bind_insert_values( + table_name, + schema_ref, + rows, + is_overwrite, + is_mapping_by_name, + ) + } + + fn insert_query( + &mut self, + table_name: TableName, + idents: &[Ident], + query: &Query, + is_overwrite: bool, + ) -> Result { + let mut input_plan = self.binder.bind_query(query, self.arena)?; + let input_schema = input_plan.output_schema(self.arena).clone(); + let input_len = input_schema.len(); + + let projection = { + let source = self + .binder + .context + .source(&table_name)? + .ok_or(DatabaseError::TableNotFound)?; + + if idents.is_empty() { + let table_schema = source.schema(); + if input_len > table_schema.len() { + return Err(DatabaseError::ValuesLenMismatch( + table_schema.len(), + input_len, + )); + } + table_schema[..input_len] + .iter() + .copied() + .enumerate() + .map(|(position, target_column)| ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr( + input_schema[position], + position, + )), + alias: AliasType::Name(self.arena.column(target_column).name().to_string()), + }) + .collect::>() + } else { + if input_len != idents.len() { + return Err(DatabaseError::ValuesLenMismatch(idents.len(), input_len)); + } + let mut projection = Vec::with_capacity(idents.len()); + for (position, ident) in idents.iter().enumerate() { + let column_name = lower_ident(ident); + let column = source.column(&column_name, self.arena).ok_or_else(|| { + attach_span_if_absent( + DatabaseError::column_not_found(column_name), + ident.span, + ) + })?; + projection.push(ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr( + input_schema[position], + position, + )), + alias: AliasType::Name(self.arena.column(column).name().to_string()), + }); + } + projection + } + }; + input_plan = self + .binder + .bind_project(input_plan, projection, self.arena)?; + + self.binder + .bind_insert_query(table_name, input_plan, is_overwrite) + } + + fn update(self, update: &sqlparser::ast::Update) -> Result { + self.binder.context.allow_default = true; + let to = &update.table; + if let TableFactor::Table { name, .. } = &to.relation { + let is_joined_update = !to.joins.is_empty(); + let table_name = sql_table_name(name.clone())?; + self.binder.with_pk(table_name.clone()); + + let mut plan = self.binder.bind_table_ref_sql(to, self.arena)?; + let (target_source, target_offset) = + Binder::<'a, 'b, T, A>::resolve_source_columns_in_scope( + &self.binder.context, + &table_name, + )?; + let target_schema = target_source.schema().to_vec(); + + if let Some(predicate) = update.selection.as_ref() { + plan = self.binder.bind_where(plan, predicate, self.arena)?; + } + let mut value_exprs = Vec::with_capacity(update.assignments.len()); + + if update.assignments.is_empty() { + return Err(DatabaseError::ColumnsEmpty); + } + for Assignment { target, value } in &update.assignments { + let expression = self.binder.bind_expr(value, self.arena)?; + let mut bind_assignment = |name: &ObjectName, + expression: ScalarExpression| + -> Result<(), DatabaseError> { + let ident = single_ident_from_object_name(name)?; + + let column = { + match self.binder.bind_column_ref_from_identifiers( + slice::from_ref(&ident), + Some(table_name.as_ref()), + self.arena, + )? { + ScalarExpression::ColumnRef { column, .. } => column, + _ => { + return Err(attach_span_if_absent( + DatabaseError::invalid_column(ident.to_string()), + ident.span, + )) + } + } + }; + + let mut expr = if matches!(expression, ScalarExpression::Empty) { + let column_catalog = self.arena.column(column); + let default_value = column_catalog + .default_value()? + .ok_or(DatabaseError::DefaultNotExist)?; + ScalarExpression::Constant(default_value) + } else { + expression + }; + let column_catalog = self.arena.column(column); + expr = ScalarExpression::type_cast( + expr, + Cow::Borrowed(column_catalog.datatype()), + self.arena, + )?; + if is_joined_update { + UpdateExprTargetRemapper { + target_schema: &target_schema, + arena: self.arena, + } + .visit(&mut expr)?; + } + value_exprs.push((column, expr)); + Ok(()) + }; + + match target { + AssignmentTarget::ColumnName(name) => bind_assignment(name, expression)?, + AssignmentTarget::Tuple(names) => { + let expected = names.len(); + let ScalarExpression::Tuple(exprs) = expression else { + return Err(DatabaseError::ValuesLenMismatch(expected, 1)); + }; + let got = exprs.len(); + let mut names = names.iter(); + let mut exprs = exprs.into_iter(); + + loop { + match (names.next(), exprs.next()) { + (Some(name), Some(expression)) => { + bind_assignment(name, expression)? + } + (None, None) => break, + _ => return Err(DatabaseError::ValuesLenMismatch(expected, got)), + } + } + } + } + } + self.binder.context.allow_default = false; + if is_joined_update { + let exprs = target_schema + .iter() + .copied() + .enumerate() + .map(|(index, column)| { + ScalarExpression::column_expr(column, target_offset + index) + }) + .collect(); + plan = LogicalPlan::new( + Operator::Project(ProjectOperator { exprs }), + Childrens::Only(Box::new(plan)), + ); + } + self.binder.bind_update(table_name, value_exprs, plan) + } else { + Err(DatabaseError::UnsupportedStmt(format!( + "UPDATE target must be a table: {:?}", + to.relation + ))) + } + } + + fn delete(self, delete: &sqlparser::ast::Delete) -> Result { + let from = match &delete.from { + FromTable::WithFromKeyword(from) | FromTable::WithoutKeyword(from) => from, + }; + let table = from + .iter() + .next() + .ok_or_else(|| DatabaseError::invalid_table("DELETE without FROM"))?; + + if let TableFactor::Table { name, .. } = &table.relation { + let table_name = sql_table_name(name.clone())?; + let primary_keys = self + .binder + .context + .table(table_name.clone())? + .ok_or(DatabaseError::TableNotFound)? + .primary_keys() + .iter() + .map(|(_, column)| *column) + .collect(); + self.binder.with_pk(table_name.clone()); + let mut plan = self.binder.bind_table_ref_sql(table, self.arena)?; + + if let Some(predicate) = delete.selection.as_ref() { + plan = self.binder.bind_where(plan, predicate, self.arena)?; + } + + self.binder.bind_delete(table_name, primary_keys, plan) + } else { + Err(DatabaseError::UnsupportedStmt(format!( + "DELETE target must be a table: {:?}", + table.relation + ))) + } + } + + fn analyze(self, analyze: sqlparser::ast::Analyze) -> Result { + let table_name = analyze.table_name.ok_or_else(|| { + DatabaseError::UnsupportedStmt("ANALYZE without table is not supported".to_string()) + })?; + self.binder + .bind_analyze(sql_table_name(table_name)?, self.arena) + } + + fn truncate(self, truncate: sqlparser::ast::Truncate) -> Result { + if truncate.table_names.len() != 1 { + return Err(DatabaseError::UnsupportedStmt( + "only truncate a single table is supported".to_string(), + )); + } + self.binder.bind_truncate(sql_table_name( + truncate.table_names.into_iter().next().unwrap().name, + )?) + } + + fn explain(self, statement: &Statement) -> Result { + let BindStatementStart { binder, arena } = self; + let plan = binder.bind(statement, arena)?; + binder.bind_explain(plan) + } +} + +impl BindStatementComplete { + fn finish(self) -> LogicalPlan { + self.plan + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanStart<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + #[allow(clippy::wrong_self_convention)] + pub(crate) fn from_sql( + self, + from: &[TableWithJoins], + ) -> Result, DatabaseError> { + let mut froms = from.iter(); + let mut plan = if let Some(from) = froms.next() { + let mut plan = self.binder.bind_table_ref_sql(from, self.arena)?; + + for from in froms { + plan = LJoinOperator::build( + plan, + self.binder.bind_table_ref_sql(from, self.arena)?, + JoinCondition::None, + JoinType::Cross, + ) + } + plan + } else { + LogicalPlan::new(Operator::Dummy, Childrens::None) + }; + plan.output_schema(self.arena); + + self.from_plan(plan) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanFrom<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn select_list_from_sql( + self, + projection: &[SelectItem], + ) -> Result, DatabaseError> { + let select_list = with_query_bind_step!(self.binder, QueryBindStep::Project, { + self.binder.normalize_select_item(projection, self.arena)? + }); + + Ok(self.select_list(select_list?)) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanSelectList<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn where_sql( + self, + selection: Option<&Expr>, + ) -> Result, DatabaseError> { + let predicate = if let Some(predicate) = selection { + Some(with_query_bind_step!(self.binder, QueryBindStep::Where, { + self.binder.bind_expr(predicate, self.arena)? + })?) + } else { + None + }; + + self.filter_expr(predicate) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanFiltered<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn aggregate_sql( + self, + group_by: &GroupByExpr, + having: Option<&Expr>, + orderby: Option<&[OrderByExpr]>, + ) -> Result, DatabaseError> { + let group_by = with_query_bind_step!(self.binder, QueryBindStep::Agg, { + match group_by { + GroupByExpr::Expressions(group_by_exprs, modifiers) => { + if !modifiers.is_empty() { + return Err(DatabaseError::UnsupportedStmt( + "GROUP BY modifiers are not supported".to_string(), + )); + } + group_by_exprs + .iter() + .map(|expr| self.binder.bind_expr(expr, self.arena)) + .collect::, DatabaseError>>()? + } + GroupByExpr::All(_) => { + return Err(DatabaseError::UnsupportedStmt( + "GROUP BY ALL is not supported".to_string(), + )) + } + } + })?; + let having = having + .map(|having| { + with_query_bind_step!(self.binder, QueryBindStep::Having, { + self.binder.bind_expr(having, self.arena)? + }) + }) + .transpose()?; + self.aggregate(group_by, having, orderby, |binder, arena, orderby| { + let OrderByExpr { expr, options, .. } = orderby; + with_query_bind_step!(binder, QueryBindStep::Sort, { + SortField::new( + binder.bind_expr(expr, arena)?, + options.asc.is_none_or(|asc| asc), + options.nulls_first.unwrap_or(false), + ) + }) + }) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanHaving<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn distinct_sql( + self, + distinct: Option<&Distinct>, + ) -> Result, DatabaseError> { + self.distinct(matches!(distinct, Some(Distinct::Distinct))) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanProjected<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn select_into_sql( + self, + into: Option<&SelectInto>, + ) -> Result { + self.insert_into( + into.map(|SelectInto { name, .. }| lower_case_name(name).map(Into::into)) + .transpose()?, + ) + } +} + +fn sql_table_name(name: ObjectName) -> Result { + Ok(lower_case_name(&name)?.into()) +} + +fn sql_table_alias(alias: TableAlias) -> TableAliasInput { + TableAliasInput { + name: lower_ident(&alias.name).into(), + columns: alias + .columns + .into_iter() + .map(|column| lower_ident(&column.name).into_owned()) + .collect(), + } +} + +fn sql_optional_table_alias(alias: Option) -> Option { + alias.map(sql_table_alias) +} + +impl From for CharLengthUnits { + fn from(value: sqlparser::ast::CharLengthUnits) -> Self { + match value { + sqlparser::ast::CharLengthUnits::Characters => Self::Characters, + sqlparser::ast::CharLengthUnits::Octets => Self::Octets, + } + } +} + +impl From for expression::TrimWhereField { + fn from(value: sqlparser::ast::TrimWhereField) -> Self { + match value { + sqlparser::ast::TrimWhereField::Both => Self::Both, + sqlparser::ast::TrimWhereField::Leading => Self::Leading, + sqlparser::ast::TrimWhereField::Trailing => Self::Trailing, + } + } +} + +impl TryFrom for expression::UnaryOperator { + type Error = DatabaseError; + + fn try_from(value: UnaryOperator) -> Result { + match value { + UnaryOperator::Plus => Ok(Self::Plus), + UnaryOperator::Minus => Ok(Self::Minus), + UnaryOperator::Not => Ok(Self::Not), + op => Err(DatabaseError::UnsupportedStmt(format!("{op}"))), + } + } +} + +impl TryFrom for expression::BinaryOperator { + type Error = DatabaseError; + + fn try_from(value: BinaryOperator) -> Result { + match value { + BinaryOperator::Plus => Ok(Self::Plus), + BinaryOperator::Minus => Ok(Self::Minus), + BinaryOperator::Multiply => Ok(Self::Multiply), + BinaryOperator::Divide => Ok(Self::Divide), + BinaryOperator::Modulo => Ok(Self::Modulo), + BinaryOperator::StringConcat => Ok(Self::StringConcat), + BinaryOperator::Gt => Ok(Self::Gt), + BinaryOperator::Lt => Ok(Self::Lt), + BinaryOperator::GtEq => Ok(Self::GtEq), + BinaryOperator::LtEq => Ok(Self::LtEq), + BinaryOperator::Spaceship => Ok(Self::Spaceship), + BinaryOperator::Eq => Ok(Self::Eq), + BinaryOperator::NotEq => Ok(Self::NotEq), + BinaryOperator::And => Ok(Self::And), + BinaryOperator::Or => Ok(Self::Or), + op => Err(DatabaseError::UnsupportedStmt(format!("{op}"))), + } + } +} + +impl TryFrom<&Value> for DataValue { + type Error = DatabaseError; + + fn try_from(value: &Value) -> Result { + Ok(match value { + Value::Number(n, _) => { + // use i32 to handle most cases + if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else { + return Err(DatabaseError::InvalidValue(n.to_string())); + } + } + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => s.clone().into(), + Value::Boolean(b) => (*b).into(), + Value::Null => Self::Null, + v => return Err(DatabaseError::UnsupportedStmt(format!("{v:?}"))), + }) + } +} + +impl TryFrom for LogicalType { + type Error = DatabaseError; + + fn try_from(value: DataType) -> Result { + match value { + DataType::Char(char_len) | DataType::Character(char_len) => { + let mut len = 1; + let mut char_unit = None; + if let Some(char_len) = char_len { + match char_len { + sqlparser::ast::CharacterLength::IntegerLength { length, unit } => { + len = cmp::max(len, length); + char_unit = unit; + } + sqlparser::ast::CharacterLength::Max => { + return Err(DatabaseError::UnsupportedStmt( + "CHAR(MAX) is not supported".to_string(), + )); + } + } + } + Ok(Self::Char( + len as u32, + char_unit + .map(Into::into) + .unwrap_or(CharLengthUnits::Characters), + )) + } + DataType::CharVarying(varchar_len) + | DataType::CharacterVarying(varchar_len) + | DataType::Varchar(varchar_len) => { + let mut len = None; + let mut char_unit = None; + if let Some(varchar_len) = varchar_len { + match varchar_len { + sqlparser::ast::CharacterLength::IntegerLength { length, unit } => { + len = Some(length as u32); + char_unit = unit; + } + sqlparser::ast::CharacterLength::Max => { + return Err(DatabaseError::UnsupportedStmt( + "VARCHAR(MAX) is not supported".to_string(), + )); + } + } + } + Ok(Self::Varchar( + len, + char_unit + .map(Into::into) + .unwrap_or(CharLengthUnits::Characters), + )) + } + DataType::String(_) | DataType::Text => { + Ok(Self::Varchar(None, CharLengthUnits::Characters)) + } + DataType::Float(_) | DataType::Float4 | DataType::Float32 | DataType::Real => { + Ok(Self::Float) + } + DataType::Double(_) + | DataType::DoublePrecision + | DataType::Float8 + | DataType::Float64 => Ok(Self::Double), + DataType::TinyInt(_) => Ok(Self::Tinyint), + DataType::TinyIntUnsigned(_) | DataType::UTinyInt => Ok(Self::UTinyint), + DataType::SmallInt(_) | DataType::Int2(_) => Ok(Self::Smallint), + DataType::SmallIntUnsigned(_) | DataType::Int2Unsigned(_) | DataType::USmallInt => { + Ok(Self::USmallint) + } + DataType::Int(_) | DataType::Integer(_) | DataType::Int4(_) | DataType::Int32 => { + Ok(Self::Integer) + } + DataType::IntUnsigned(_) + | DataType::IntegerUnsigned(_) + | DataType::Int4Unsigned(_) + | DataType::Unsigned + | DataType::UnsignedInteger + | DataType::UInt32 => Ok(Self::UInteger), + DataType::BigInt(_) | DataType::Int8(_) | DataType::Int64 => Ok(Self::Bigint), + DataType::BigIntUnsigned(_) + | DataType::Int8Unsigned(_) + | DataType::UBigInt + | DataType::UInt64 => Ok(Self::UBigint), + DataType::Boolean => Ok(Self::Boolean), + DataType::Date => { + #[cfg(feature = "time")] + { + Ok(Self::Date) + } + #[cfg(not(feature = "time"))] + { + Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )) + } + } + DataType::Datetime(precision) => { + #[cfg(feature = "time")] + { + if precision.is_some() { + return Err(DatabaseError::UnsupportedStmt( + "time's precision".to_string(), + )); + } + Ok(Self::DateTime) + } + #[cfg(not(feature = "time"))] + { + let _ = precision; + Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )) + } + } + DataType::Time(precision, info) => { + #[cfg(feature = "time")] + { + match precision { + Some(0..5) | None => (), + _ => { + return Err(DatabaseError::UnsupportedStmt( + "time's precision must be less than 5".to_string(), + )) + } + } + if !matches!(info, sqlparser::ast::TimezoneInfo::None) { + return Err(DatabaseError::UnsupportedStmt( + "time's zone is not supported".to_string(), + )); + } + Ok(Self::Time(precision)) + } + #[cfg(not(feature = "time"))] + { + let _ = (precision, info); + Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )) + } + } + DataType::Timestamp(precision, info) => { + #[cfg(feature = "time")] + { + let mut zone = false; + match precision { + Some(3 | 6 | 9) | None => (), + _ => { + return Err(DatabaseError::UnsupportedStmt( + "timestamp's precision must be 3,6,9".to_string(), + )) + } + } + if matches!(info, sqlparser::ast::TimezoneInfo::WithTimeZone) { + zone = true; + } + Ok(Self::TimeStamp(precision, zone)) + } + #[cfg(not(feature = "time"))] + { + let _ = (precision, info); + Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )) + } + } + DataType::Decimal(info) + | DataType::DecimalUnsigned(info) + | DataType::Dec(info) + | DataType::DecUnsigned(info) + | DataType::Numeric(info) => { + #[cfg(feature = "decimal")] + { + match info { + sqlparser::ast::ExactNumberInfo::None => Ok(Self::Decimal(None, None)), + sqlparser::ast::ExactNumberInfo::Precision(p) => { + Ok(Self::Decimal(Some(p as u8), None)) + } + sqlparser::ast::ExactNumberInfo::PrecisionAndScale(p, s) => { + Ok(Self::Decimal(Some(p as u8), Some(s as u8))) + } + } + } + #[cfg(not(feature = "decimal"))] + { + let _ = info; + Err(DatabaseError::UnsupportedStmt( + "DECIMAL requires the `decimal` feature".to_string(), + )) + } + } + other => Err(DatabaseError::UnsupportedStmt(format!( + "unsupported data type: {other}" + ))), + } + } +} + +fn sql_object_name(name: ObjectName) -> Result { + Ok(lower_case_name(&name)?.into_owned()) +} + +pub(super) fn lower_ident(ident: &Ident) -> Cow<'_, str> { + let value = &ident.value; + + if value.chars().any(char::is_uppercase) { + Cow::Owned(value.to_lowercase()) + } else { + Cow::Borrowed(value) + } +} + +fn lower_name_part(part: &ObjectNamePart) -> Result, DatabaseError> { + part.as_ident() + .map(lower_ident) + .ok_or_else(|| attach_span_if_absent(DatabaseError::invalid_table(part.to_string()), part)) +} + +/// Convert an object name into lower case. +pub(super) fn lower_case_name(name: &ObjectName) -> Result, DatabaseError> { + if name.0.len() == 1 { + return lower_name_part(&name.0[0]); + } + Err(attach_span_if_absent( + DatabaseError::invalid_table(name.to_string()), + name, + )) +} + +fn single_ident_from_object_name(name: &ObjectName) -> Result { + if name.0.len() != 1 { + return Err(attach_span_if_absent( + DatabaseError::invalid_column(name.to_string()), + name, + )); + } + match name.0.first() { + Some(ObjectNamePart::Identifier(ident)) => Ok(ident.clone()), + Some(part) => Err(DatabaseError::invalid_column(part.to_string())), + None => Err(DatabaseError::invalid_column(String::new())), + } +} + +fn sql_index_name(name: ObjectName) -> Result<(TableName, String), DatabaseError> { + let table_name = name.0.first().ok_or_else(|| { + attach_span_if_absent(DatabaseError::invalid_table(name.to_string()), &name) + })?; + let index_name = name.0.get(1).ok_or(DatabaseError::InvalidIndex)?; + + Ok(( + lower_name_part(table_name)?.into(), + lower_name_part(index_name)?.into_owned(), + )) +} + +fn query_output_aliases(query: &Query) -> Vec> { + let SetExpr::Select(select) = query.body.as_ref() else { + return Vec::new(); + }; + + select + .projection + .iter() + .map(|item| match item { + SelectItem::ExprWithAlias { alias, .. } => Some(lower_ident(alias).into_owned()), + _ => None, + }) + .collect() +} + +#[cfg(feature = "copy")] +fn copy_ext_source( + target: CopyTarget, + options: Vec, +) -> Result { + Ok(ExtSource { + path: match target { + CopyTarget::File { filename } => filename.into(), + t => { + return Err(DatabaseError::UnsupportedStmt(format!( + "copy target: {t:?}" + ))) + } + }, + format: copy_file_format(options)?, + }) +} + +#[cfg(feature = "copy")] +fn copy_file_format(options: Vec) -> Result { + let mut delimiter = ','; + let mut quote = '"'; + let mut escape = None; + let mut header = false; + for opt in options { + match opt { + CopyOption::Format(fmt) => { + debug_assert_eq!(fmt.value.to_lowercase(), "csv", "only support CSV format") + } + CopyOption::Delimiter(c) => delimiter = c, + CopyOption::Header(b) => header = b, + CopyOption::Quote(c) => quote = c, + CopyOption::Escape(c) => escape = Some(c), + o => { + return Err(DatabaseError::UnsupportedStmt(format!( + "copy option: {o:?}" + ))) + } + } + } + Ok(FileFormat::Csv { + delimiter, + quote, + escape, + header, + }) +} + +impl<'a, 'parent, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'parent, T, A> { + fn bind_table_ref_sql( + &mut self, + from: &TableWithJoins, + arena: &mut PlanArena, + ) -> Result { + self.context.step(QueryBindStep::From); + + let TableWithJoins { relation, joins } = from; + let mut plan = self.bind_single_table_ref_sql(relation, None, arena)?; + + for join in joins { + plan = self.bind_join_sql(plan, join, arena)?; + } + Ok(plan) + } + + fn bind_single_table_ref_sql( + &mut self, + table: &TableFactor, + joint_type: Option, + arena: &mut PlanArena, + ) -> Result { + match table { + TableFactor::Table { name, alias, .. } => self.bind_base_table_ref( + joint_type, + sql_table_name(name.clone())?, + sql_optional_table_alias(alias.clone()), + arena, + ), + TableFactor::Derived { + subquery, alias, .. + } => { + let mut binder = Binder::new(self.context.fork(), self.args, Some(&self.context)); + let plan = binder.bind_query(subquery, arena)?; + self.bind_derived_source( + plan, + sql_optional_table_alias(alias.clone()), + joint_type, + arena, + ) + } + TableFactor::TableFunction { expr, alias } => { + let expr = self.bind_expr(expr, arena)?; + self.bind_table_function_source( + expr, + sql_optional_table_alias(alias.clone()), + joint_type, + arena, + ) + } + table => Err(DatabaseError::UnsupportedStmt(format!("{table:#?}"))), + } + } + + fn bind_join_sql( + &mut self, + left: LogicalPlan, + join: &Join, + arena: &mut PlanArena, + ) -> Result { + let Join { + relation, + join_operator, + .. + } = join; + + let (join_type, joint_condition) = match join_operator { + JoinOperator::Join(constraint) + | JoinOperator::Inner(constraint) + | JoinOperator::StraightJoin(constraint) => (JoinType::Inner, Some(constraint)), + JoinOperator::Left(constraint) | JoinOperator::LeftOuter(constraint) => { + (JoinType::LeftOuter, Some(constraint)) + } + JoinOperator::Right(constraint) | JoinOperator::RightOuter(constraint) => { + (JoinType::RightOuter, Some(constraint)) + } + JoinOperator::FullOuter(constraint) => (JoinType::Full, Some(constraint)), + JoinOperator::CrossJoin(constraint) => (JoinType::Cross, Some(constraint)), + JoinOperator::Semi(_) + | JoinOperator::LeftSemi(_) + | JoinOperator::Anti(_) + | JoinOperator::LeftAnti(_) + | JoinOperator::RightSemi(_) + | JoinOperator::RightAnti(_) + | JoinOperator::CrossApply + | JoinOperator::OuterApply + | JoinOperator::AsOf { .. } => { + return Err(DatabaseError::UnsupportedStmt(format!("{join_operator:?}"))) + } + }; + let (right, context) = { + let mut binder = Binder::new(self.context.fork_empty(), self.args, Some(&self.context)); + let right = binder.bind_single_table_ref_sql(relation, Some(join_type), arena)?; + (right, binder.context) + }; + self.extend(context); + + let constraint = match joint_condition { + Some(constraint) => self.bind_join_constraint_sql(constraint, arena)?, + None => JoinConstraintInput::None, + }; + + self.bind_join_plans(left, right, join_type, constraint, arena) + } + + fn bind_join_constraint_sql( + &mut self, + constraint: &JoinConstraint, + arena: &mut PlanArena, + ) -> Result { + match constraint { + JoinConstraint::On(expr) => Ok(JoinConstraintInput::On(self.bind_expr(expr, arena)?)), + JoinConstraint::Using(names) => Ok(JoinConstraintInput::Using( + names + .iter() + .map(|name| lower_case_name(name).map(Cow::into_owned)) + .collect::>()?, + )), + JoinConstraint::Natural => Ok(JoinConstraintInput::Natural), + JoinConstraint::None => Ok(JoinConstraintInput::None), + } + } + + fn parse_like_escape_char(escape_char: &Option) -> Result, DatabaseError> { + match escape_char { + None => Ok(None), + Some(value) => match value { + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => { + let mut chars = s.chars(); + let ch = chars.next().ok_or(DatabaseError::InvalidValue( + "escape character must not be empty".to_string(), + ))?; + if chars.next().is_some() { + return Err(DatabaseError::InvalidValue( + "escape character must be a single character".to_string(), + )); + } + Ok(Some(ch)) + } + _ => Err(DatabaseError::InvalidValue( + "escape character must be a quoted string".to_string(), + )), + }, + } + } + + pub(crate) fn bind_expr( + &mut self, + expr: &Expr, + arena: &mut PlanArena, + ) -> Result { + let expr_span = expr.span(); + match expr { + Expr::Identifier(ident) => { + self.bind_column_ref_from_identifiers(slice::from_ref(ident), None, arena) + } + Expr::CompoundIdentifier(idents) => { + self.bind_column_ref_from_identifiers(idents, None, arena) + } + Expr::BinaryOp { left, right, op } => { + let left_expr = self.bind_expr(left, arena)?; + let right_expr = self.bind_expr(right, arena)?; + self.bind_binary_op_expr(left_expr, right_expr, op.clone().try_into()?, arena) + } + Expr::Value(v) => { + let value = if let Value::Placeholder(name) = &v.value { + self.args + .as_ref() + .iter() + .find_map(|(key, value)| (key == name).then(|| value.clone())) + .ok_or_else(|| { + attach_span_if_absent( + DatabaseError::parameter_not_found(name.to_string()), + v, + ) + })? + } else { + (&v.value) + .try_into() + .map_err(|err| attach_span_if_absent(err, v))? + }; + Ok(ScalarExpression::Constant(value)) + } + Expr::Function(func) => self.bind_function_sql(func, arena), + Expr::Nested(expr) => self.bind_expr(expr, arena), + Expr::UnaryOp { expr, op } => { + let expr = self.bind_expr(expr, arena)?; + self.bind_unary_op_expr(expr, (*op).try_into()?, arena) + } + Expr::Like { + negated, + expr, + pattern, + escape_char, + any: _, + } => { + let left_expr = Box::new(self.bind_expr(expr, arena)?); + let right_expr = Box::new(self.bind_expr(pattern, arena)?); + let escape_char = Self::parse_like_escape_char(escape_char)?; + let op = if *negated { + expression::BinaryOperator::NotLike(escape_char) + } else { + expression::BinaryOperator::Like(escape_char) + }; + Ok(ScalarExpression::Binary { + op, + left_expr, + right_expr, + evaluator: None, + ty: LogicalType::Boolean, + }) + } + Expr::IsNull(expr) => Ok(ScalarExpression::IsNull { + negated: false, + expr: Box::new(self.bind_expr(expr, arena)?), + }), + Expr::IsNotNull(expr) => Ok(ScalarExpression::IsNull { + negated: true, + expr: Box::new(self.bind_expr(expr, arena)?), + }), + Expr::InList { + expr, + list, + negated, + } => { + let args = list + .iter() + .map(|expr| self.bind_expr(expr, arena)) + .try_collect()?; + Ok(ScalarExpression::In { + negated: *negated, + expr: Box::new(self.bind_expr(expr, arena)?), + args, + }) + } + Expr::Cast { + expr, data_type, .. + } => ScalarExpression::type_cast( + self.bind_expr(expr, arena)?, + Cow::Owned(LogicalType::try_from(data_type.clone())?), + arena, + ), + Expr::TypedString(TypedString { + data_type, value, .. + }) => { + let logical_type = LogicalType::try_from(data_type.clone())?; + let raw = value.clone().into_string().ok_or_else(|| { + DatabaseError::InvalidValue("typed string literal must be a string".to_string()) + })?; + let value = DataValue::Utf8 { + value: raw, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + } + .cast(&logical_type) + .map_err(|err| attach_span_if_absent(err, expr_span))?; + + Ok(ScalarExpression::Constant(value)) + } + Expr::Between { + expr, + negated, + low, + high, + } => Ok(ScalarExpression::Between { + negated: *negated, + expr: Box::new(self.bind_expr(expr, arena)?), + left_expr: Box::new(self.bind_expr(low, arena)?), + right_expr: Box::new(self.bind_expr(high, arena)?), + }), + Expr::Substring { + expr, + substring_for, + substring_from, + .. + } => { + let mut for_expr = None; + let mut from_expr = None; + + if let Some(expr) = substring_for { + for_expr = Some(Box::new(self.bind_expr(expr, arena)?)) + } + if let Some(expr) = substring_from { + from_expr = Some(Box::new(self.bind_expr(expr, arena)?)) + } + + Ok(ScalarExpression::SubString { + expr: Box::new(self.bind_expr(expr, arena)?), + for_expr, + from_expr, + }) + } + Expr::Position { expr, r#in } => Ok(ScalarExpression::Position { + expr: Box::new(self.bind_expr(expr, arena)?), + in_expr: Box::new(self.bind_expr(r#in, arena)?), + }), + Expr::Trim { + expr, + trim_what, + trim_where, + .. + } => { + let mut trim_what_expr = None; + if let Some(trim_what) = trim_what { + trim_what_expr = Some(Box::new(self.bind_expr(trim_what, arena)?)) + } + Ok(ScalarExpression::Trim { + expr: Box::new(self.bind_expr(expr, arena)?), + trim_what_expr, + trim_where: (*trim_where).map(Into::into), + }) + } + Expr::Exists { subquery, negated } => { + self.bind_exists_subquery_plan(*negated, arena, |binder, arena| { + binder.bind_query(subquery, arena) + }) + } + Expr::Subquery(subquery) => self.bind_scalar_subquery_plan(arena, |binder, arena| { + binder.bind_query(subquery, arena) + }), + Expr::InSubquery { + expr, + subquery, + negated, + } => self.bind_quantified_subquery( + MarkApplyQuantifier::Any, + *negated, + expr, + &BinaryOperator::Eq, + subquery, + arena, + ), + Expr::Tuple(exprs) => { + let mut bound_exprs = Vec::with_capacity(exprs.len()); + + for expr in exprs { + bound_exprs.push(self.bind_expr(expr, arena)?); + } + Ok(ScalarExpression::Tuple(bound_exprs)) + } + Expr::Case { + operand, + conditions, + else_result, + .. + } => { + let fn_check_ty = |ty: &mut LogicalType, result_ty| { + if result_ty != LogicalType::SqlNull { + if ty == &LogicalType::SqlNull { + *ty = result_ty; + } else if ty != &result_ty { + return Err(DatabaseError::Incomparable(ty.clone(), result_ty)); + } + } + + Ok(()) + }; + let mut operand_expr = None; + let mut ty = LogicalType::SqlNull; + if let Some(expr) = operand { + operand_expr = Some(Box::new(self.bind_expr(expr, arena)?)); + } + let mut expr_pairs = Vec::with_capacity(conditions.len()); + for when in conditions { + let result = self.bind_expr(&when.result, arena)?; + let result_ty = result.return_type(arena).into_owned(); + + fn_check_ty(&mut ty, result_ty)?; + expr_pairs.push((self.bind_expr(&when.condition, arena)?, result)) + } + + let mut else_expr = None; + if let Some(expr) = else_result { + let temp_expr = Box::new(self.bind_expr(expr, arena)?); + let else_ty = temp_expr.return_type(arena).into_owned(); + + fn_check_ty(&mut ty, else_ty)?; + else_expr = Some(temp_expr); + } + + Ok(ScalarExpression::CaseWhen { + operand_expr, + expr_pairs, + else_expr, + ty, + }) + } + Expr::AnyOp { + left, + compare_op, + right, + .. + } => self.bind_quantified_op(MarkApplyQuantifier::Any, left, compare_op, right, arena), + Expr::AllOp { + left, + compare_op, + right, + } => self.bind_quantified_op(MarkApplyQuantifier::All, left, compare_op, right, arena), + expr => Err(DatabaseError::UnsupportedStmt(expr.to_string())), + } + } + + fn bind_quantified_op( + &mut self, + quantifier: MarkApplyQuantifier, + left: &Expr, + compare_op: &BinaryOperator, + right: &Expr, + arena: &mut PlanArena, + ) -> Result { + let Expr::Subquery(subquery) = right else { + return Err(DatabaseError::UnsupportedStmt(format!( + "{quantifier:?} only supports subquery operands" + ))); + }; + + self.bind_quantified_subquery(quantifier, false, left, compare_op, subquery, arena) + } + + fn bind_quantified_subquery( + &mut self, + quantifier: MarkApplyQuantifier, + negated: bool, + expr: &Expr, + compare_op: &BinaryOperator, + subquery: &Query, + arena: &mut PlanArena, + ) -> Result { + let left_expr = self.bind_expr(expr, arena)?; + self.bind_quantified_subquery_plan( + quantifier, + negated, + left_expr, + compare_op.clone().try_into()?, + arena, + |binder, arena| binder.bind_query(subquery, arena), + ) + } + + pub fn bind_column_ref_from_identifiers( + &mut self, + idents: &[Ident], + bind_table_name: Option<&str>, + arena: &mut PlanArena, + ) -> Result { + let full_name = match idents { + [column] => (None, lower_ident(column)), + [table, column] => (Some(lower_ident(table)), lower_ident(column)), + _ => { + let invalid_name = idents + .iter() + .map(|ident| ident.value.clone()) + .join(".") + .to_string(); + let err = DatabaseError::invalid_column(invalid_name); + return Err(match idents.last() { + Some(ident) => attach_span_if_absent(err, ident.span), + None => err, + }); + } + }; + self.bind_column_ref_by_name( + full_name.0.as_deref(), + full_name.1.as_ref(), + bind_table_name, + arena, + ) + .map_err(|err| match idents.last() { + Some(ident) => attach_span_if_absent(err, ident.span), + None => err, + }) + } + + pub(crate) fn bind_function_sql( + &mut self, + func: &Function, + arena: &mut PlanArena, + ) -> Result { + let func_span = func.span(); + let Function { name, args, .. } = func; + let (func_args, is_distinct) = match args { + FunctionArguments::List(args) => ( + args.args.as_slice(), + matches!(args.duplicate_treatment, Some(DuplicateTreatment::Distinct)), + ), + FunctionArguments::None => (&[][..], false), + FunctionArguments::Subquery(_) => { + return Err(DatabaseError::UnsupportedStmt( + "subquery function args are not supported".to_string(), + )) + } + }; + let mut args = Vec::with_capacity(func_args.len()); + + for arg in func_args { + let arg_expr = match arg { + FunctionArg::Named { arg, .. } => arg, + FunctionArg::ExprNamed { arg, .. } => arg, + FunctionArg::Unnamed(arg) => arg, + }; + match arg_expr { + FunctionArgExpr::Expr(expr) => args.push(self.bind_expr(expr, arena)?), + FunctionArgExpr::Wildcard => args.push(Self::wildcard_expr()), + expr => { + return Err(DatabaseError::UnsupportedStmt(format!( + "function arg: {expr:#?}" + ))) + } + } + } + let function_name = name.to_string().to_lowercase(); + + self.bind_function_call(function_name, args, is_distinct, arena) + .map_err(|err| attach_span_if_absent(err, func_span)) + } + + pub fn bind_set_expr( + &mut self, + set_expr: &SetExpr, + arena: &mut PlanArena, + ) -> Result { + match set_expr { + SetExpr::Select(select) => self.bind_select(select, None, arena), + SetExpr::Query(query) => self.bind_query(query, arena), + SetExpr::SetOperation { + op, + set_quantifier, + left, + right, + } => self.bind_set_operation(op, set_quantifier, left, right, arena), + expr => Err(DatabaseError::UnsupportedStmt(format!( + "set expression: {expr:?}" + ))), + } + } + + fn bind_set_operation( + &mut self, + op: &SetOperator, + set_quantifier: &SetQuantifier, + left: &SetExpr, + right: &SetExpr, + arena: &mut PlanArena, + ) -> Result { + let is_all = match set_quantifier { + SetQuantifier::All => true, + SetQuantifier::Distinct | SetQuantifier::None => false, + SetQuantifier::ByName | SetQuantifier::AllByName | SetQuantifier::DistinctByName => { + return Err(DatabaseError::UnsupportedStmt( + "set quantifier BY NAME is not supported".to_string(), + )) + } + }; + let op = match op { + SetOperator::Union => SetOperatorKind::Union, + SetOperator::Except => SetOperatorKind::Except, + SetOperator::Intersect => SetOperatorKind::Intersect, + op => { + return Err(DatabaseError::UnsupportedStmt(format!( + "set operator: {op:?}" + ))) + } + }; + let left_plan = { + let mut left_binder = Binder::new(self.context.fork(), self.args, self.parent); + let plan = left_binder.bind_set_expr(left, arena)?; + if left_binder.context.has_outer_refs() { + self.context.mark_outer_ref(); + } + plan + }; + + let right_plan = { + let mut right_binder = Binder::new(self.context.fork(), self.args, self.parent); + let plan = right_binder.bind_set_expr(right, arena)?; + if right_binder.context.has_outer_refs() { + self.context.mark_outer_ref(); + } + plan + }; + + self.bind_set_operation_plans(op, is_all, left_plan, right_plan, arena) + } + + pub(crate) fn bind_select( + &mut self, + select: &Select, + orderby: Option<&[OrderByExpr]>, + arena: &mut PlanArena, + ) -> Result { + let Select { + projection, + from, + selection, + group_by, + having, + distinct, + into, + .. + } = select; + Ok(self + .build_plan(arena) + .from_sql(from)? + .select_list_from_sql(projection)? + .where_sql(selection.as_ref())? + .aggregate_sql(group_by, having.as_ref(), orderby)? + .having()? + .distinct_sql(distinct.as_ref())? + .order_by()? + .project()? + .select_into_sql(into.as_ref())? + .finish()) + } + + /// FIXME: temp values need to register BindContext.bind_table + fn bind_temp_values( + &mut self, + expr_rows: &[Vec], + arena: &mut PlanArena, + ) -> Result { + let values_len = expr_rows[0].len(); + + let mut inferred_types: Vec> = vec![None; values_len]; + let mut rows = Vec::with_capacity(expr_rows.len()); + + for expr_row in expr_rows { + if expr_row.len() != values_len { + return Err(DatabaseError::ValuesLenMismatch(expr_row.len(), values_len)); + } + + let mut row = Vec::with_capacity(values_len); + + for (col_index, expr) in expr_row.iter().enumerate() { + let mut expression = self.bind_expr(expr, arena)?; + ConstantCalculator::new(arena).visit(&mut expression)?; + + if let ScalarExpression::Constant(value) = expression { + let value_type = value.logical_type(); + + inferred_types[col_index] = match &inferred_types[col_index] { + Some(existing) => { + Some(LogicalType::max_logical_type(existing, &value_type)?.into_owned()) + } + None => Some(value_type), + }; + + row.push(value); + } else { + return Err(DatabaseError::ColumnsEmpty); + } + } + + rows.push(row); + } + + let value_name = arena.temp_table(); + let column_refs: Vec = inferred_types + .into_iter() + .enumerate() + .map(|(col_index, typ)| { + let typ = typ.ok_or(DatabaseError::InvalidType)?; + let mut column_ref = ColumnCatalog::new( + col_index.to_string(), + false, + ColumnDesc::new(typ, None, false, None)?, + ); + column_ref.set_ref_table(value_name.clone(), ColumnId::default(), true); + Ok(arena.alloc_column(column_ref)) + }) + .collect::>()?; + + Ok(self.bind_values(rows, column_refs)) + } + + fn bind_top_level_orderby( + &mut self, + mut plan: LogicalPlan, + orderbys: &[OrderByExpr], + arena: &mut PlanArena, + ) -> Result { + let saved_aliases = self.context.expr_aliases.clone(); + let output_schema = plan.output_schema(arena); + for (position, column) in output_schema.iter().enumerate() { + self.context.add_alias( + None, + arena.column(*column).name().to_string(), + ScalarExpression::column_expr(*column, position), + ); + } + + let sort_fields = self + .extract_having_orderby_aggregate_exprs(None, Some(orderbys), |binder, orderby| { + let OrderByExpr { expr, options, .. } = orderby; + Ok(SortField::new( + binder.bind_expr(expr, arena)?, + options.asc.is_none_or(|asc| asc), + options.nulls_first.unwrap_or(false), + )) + })? + .1; + self.context.expr_aliases = saved_aliases; + + Ok(match sort_fields { + Some(sort_fields) => self.bind_sort(plan, sort_fields, arena)?, + None => plan, + }) + } + + pub(crate) fn bind_where( + &mut self, + children: LogicalPlan, + predicate: &Expr, + arena: &mut PlanArena, + ) -> Result { + let predicate = with_query_bind_step!(self, QueryBindStep::Where, { + self.bind_expr(predicate, arena)? + })?; + + self.bind_where_expr(children, predicate, arena) + } + + pub(crate) fn normalize_select_item( + &mut self, + items: &[SelectItem], + arena: &mut PlanArena, + ) -> Result, DatabaseError> { + let mut select_items = vec![]; + + for item in items { + match item { + SelectItem::UnnamedExpr(expr) => select_items.push(self.bind_expr(expr, arena)?), + SelectItem::ExprWithAlias { expr, alias } => { + let expr = self.bind_expr(expr, arena)?; + let alias_name = lower_ident(alias).into_owned(); + + self.context + .add_alias(None, alias_name.clone(), expr.clone()); + + select_items.push(ScalarExpression::Alias { + expr: Box::new(expr), + alias: AliasType::Name(alias_name), + }); + } + SelectItem::Wildcard(_) => { + let visible_names = self + .context + .bind_table + .iter() + .filter(|bound_source| { + !Self::is_joined_values_source( + bound_source.join_type, + &bound_source.source, + arena, + ) + }) + .map(|bound_source| bound_source.visible_name()) + .unique() + .cloned() + .collect_vec(); + for visible_name in visible_names { + Self::bind_table_column_refs( + &self.context, + arena, + &mut select_items, + visible_name, + false, + )?; + } + } + SelectItem::QualifiedWildcard(table_name, _) => { + let table_name: TableName = match table_name { + SelectItemQualifiedWildcardKind::ObjectName(name) => { + lower_case_name(name)?.into() + } + SelectItemQualifiedWildcardKind::Expr(expr) => { + return Err(DatabaseError::UnsupportedStmt(format!( + "qualified wildcard expr: {expr}" + ))) + } + }; + Self::bind_table_column_refs( + &self.context, + arena, + &mut select_items, + table_name, + true, + )?; + } + }; + } + + Ok(select_items) + } + + pub(crate) fn bind_query( + &mut self, + query: &Query, + arena: &mut PlanArena, + ) -> Result { + let origin_step = self.context.step_now(); + + if let Some(_with) = &query.with { + // TODO support with clause. + } + + let order_by_exprs = if let Some(order_by) = &query.order_by { + match &order_by.kind { + OrderByKind::Expressions(exprs) => Some(exprs.as_slice()), + OrderByKind::All(_) => { + return Err(DatabaseError::UnsupportedStmt( + "ORDER BY ALL is not supported".to_string(), + )) + } + } + } else { + None + }; + let is_plain_select = matches!(query.body.as_ref(), SetExpr::Select(_)); + let mut plan = match query.body.as_ref() { + SetExpr::Select(select) => self.bind_select(select, order_by_exprs, arena), + SetExpr::Query(query) => self.bind_query(query, arena), + SetExpr::SetOperation { + op, + set_quantifier, + left, + right, + } => self.bind_set_operation(op, set_quantifier, left, right, arena), + SetExpr::Values(values) => self.bind_temp_values(&values.rows, arena), + expr => { + return Err(DatabaseError::UnsupportedStmt(format!( + "query body: {expr:?}" + ))) + } + }?; + + if !is_plain_select { + if let Some(order_by_exprs) = order_by_exprs { + plan = self.bind_top_level_orderby(plan, order_by_exprs, arena)?; + } + } + + if let Some(limit_clause) = &query.limit_clause { + plan = self.bind_limit(plan, limit_clause, arena)?; + } + + self.context.step(origin_step); + Ok(plan) + } + + fn bind_non_negative_limit_value( + &mut self, + expr: &Expr, + arena: &mut PlanArena, + ) -> Result { + let span = expr.span(); + let bound_expr = self.bind_expr(expr, arena)?; + match bound_expr { + ScalarExpression::Constant(dv) => match &dv { + DataValue::Int32(v) if *v >= 0 => Ok(*v as usize), + DataValue::Int64(v) if *v >= 0 => Ok(*v as usize), + _ => Err(DatabaseError::InvalidType), + }, + _ => Err(attach_span_if_absent( + DatabaseError::invalid_column("invalid limit expression.".to_owned()), + span, + )), + } + } + + fn bind_limit( + &mut self, + children: LogicalPlan, + limit: &LimitClause, + arena: &mut PlanArena, + ) -> Result { + let mut limit_value = None; + let mut offset_value = None; + match limit { + LimitClause::LimitOffset { + limit: limit_expr, + offset: offset_expr, + limit_by, + } => { + if !limit_by.is_empty() { + return Err(DatabaseError::UnsupportedStmt( + "LIMIT BY is not supported".to_string(), + )); + } + + if let Some(limit_ast) = limit_expr { + limit_value = Some(self.bind_non_negative_limit_value(limit_ast, arena)?); + } + + if let Some(offset_ast) = offset_expr { + offset_value = + Some(self.bind_non_negative_limit_value(&offset_ast.value, arena)?); + } + } + LimitClause::OffsetCommaLimit { + offset: offset_expr, + limit: limit_expr, + } => { + limit_value = Some(self.bind_non_negative_limit_value(limit_expr, arena)?); + offset_value = Some(self.bind_non_negative_limit_value(offset_expr, arena)?); + } + } + + self.bind_limit_values(children, offset_value, limit_value) + } + + fn build_statement<'s, 'arena>( + &'s mut self, + arena: &'s mut PlanArena<'arena>, + ) -> BindStatementStart<'s, 'a, 'parent, 'arena, T, A> { + BindStatementStart { + binder: self, + arena, + } + } + + pub fn bind( + &mut self, + stmt: &Statement, + arena: &mut PlanArena, + ) -> Result { + Ok(self.build_statement(arena).statement(stmt)?.finish()) + } +} diff --git a/src/binder/select.rs b/src/binder/select.rs index 40d7b21b..424e9088 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -24,21 +24,13 @@ use crate::{ }, types::value::DataValue, }; -use std::borrow::Borrow; use std::collections::HashSet; -use std::sync::Arc; -use super::{ - attach_span_if_absent, lower_case_name, lower_ident, Binder, BinderContext, QueryBindStep, - Source, SubQueryType, -}; +use super::{Binder, BinderContext, QueryBindStep, SetOperatorKind, Source, SubQueryType}; -use crate::catalog::{ - ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary, TableName, -}; +use crate::catalog::{ColumnRef, ColumnRelation, TableName}; use crate::errors::DatabaseError; use crate::execution::dql::join::joins_nullable; -use crate::expression::simplify::ConstantCalculator; use crate::expression::visitor_mut::{walk_mut_expr, PositionShift, VisitorMut}; use crate::expression::{AliasType, BinaryOperator}; use crate::planner::operator::function_scan::FunctionScanOperator; @@ -47,29 +39,29 @@ use crate::planner::operator::join::JoinCondition; use crate::planner::operator::set_membership::{SetMembershipKind, SetMembershipOperator}; use crate::planner::operator::sort::{SortField, SortOperator}; use crate::planner::operator::union::UnionOperator; -use crate::planner::{Childrens, LogicalPlan, SchemaOutput}; +use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; -use crate::types::tuple::{Schema, SchemaRef}; +use crate::types::tuple::Schema; use crate::types::{ColumnId, LogicalType}; use itertools::Itertools; -use sqlparser::ast::{ - Distinct, Expr, GroupByExpr, Join, JoinConstraint, JoinOperator, LimitClause, OrderByExpr, - OrderByKind, Query, Select, SelectInto, SelectItem, SelectItemQualifiedWildcardKind, SetExpr, - SetOperator, SetQuantifier, TableAlias, TableAliasColumnDef, TableFactor, TableWithJoins, -}; -struct RightSidePositionGlobalizer<'a> { +struct RightSidePositionGlobalizer<'a, 'p> { right_schema: &'a Schema, left_len: usize, + arena: &'a crate::planner::PlanArena<'p>, } -impl<'a> VisitorMut<'a> for RightSidePositionGlobalizer<'_> { +impl<'a> VisitorMut<'a> for RightSidePositionGlobalizer<'_, '_> { fn visit_column_ref( &mut self, column: &'a mut ColumnRef, position: &'a mut usize, ) -> Result<(), DatabaseError> { - if self.right_schema.contains(column) { + if self + .right_schema + .iter() + .any(|right| self.arena.same_column(*right, *column)) + { *position += self.left_len; } Ok(()) @@ -82,12 +74,13 @@ struct AppendedRightOutput { output_position: usize, } -struct SplitScopePositionRebinder<'a> { +struct SplitScopePositionRebinder<'a, 'p> { left_schema: &'a Schema, right_schema: &'a Schema, + arena: &'a crate::planner::PlanArena<'p>, } -impl VisitorMut<'_> for SplitScopePositionRebinder<'_> { +impl VisitorMut<'_> for SplitScopePositionRebinder<'_, '_> { fn visit_column_ref( &mut self, column: &mut ColumnRef, @@ -96,13 +89,13 @@ impl VisitorMut<'_> for SplitScopePositionRebinder<'_> { if let Some(left_position) = self .left_schema .iter() - .position(|candidate| candidate.same_column(column)) + .position(|candidate| self.arena.same_column(*candidate, *column)) { *position = left_position; } else if let Some(right_position) = self .right_schema .iter() - .position(|candidate| candidate.same_column(column)) + .position(|candidate| self.arena.same_column(*candidate, *column)) { *position = right_position; } @@ -110,47 +103,58 @@ impl VisitorMut<'_> for SplitScopePositionRebinder<'_> { } } -struct MarkerPositionGlobalizer<'a> { +struct MarkerPositionGlobalizer<'a, 'p> { output_column: &'a ColumnRef, left_len: usize, + arena: &'a crate::planner::PlanArena<'p>, } -impl VisitorMut<'_> for MarkerPositionGlobalizer<'_> { +impl VisitorMut<'_> for MarkerPositionGlobalizer<'_, '_> { fn visit_column_ref( &mut self, column: &mut ColumnRef, position: &mut usize, ) -> Result<(), DatabaseError> { - if column.same_column(self.output_column) { + if self.arena.same_column(*column, *self.output_column) { *position = self.left_len; } Ok(()) } } -struct ProjectionOutputBinder<'a> { +struct ProjectionOutputBinder<'a, 'p> { project_exprs: &'a [ScalarExpression], + arena: &'a mut crate::planner::PlanArena<'p>, } -impl<'a> ProjectionOutputBinder<'a> { - fn new(project_exprs: &'a [ScalarExpression]) -> Self { - Self { project_exprs } +impl<'a, 'p> ProjectionOutputBinder<'a, 'p> { + fn new( + project_exprs: &'a [ScalarExpression], + arena: &'a mut crate::planner::PlanArena<'p>, + ) -> Self { + Self { + project_exprs, + arena, + } } - fn output_ref(&self, expr: &ScalarExpression) -> Option { + fn output_ref(&mut self, expr: &ScalarExpression) -> Option { self.project_exprs .iter() .position(|candidate| { - candidate == expr || candidate.unpack_alias_ref() == expr.unpack_alias_ref() + candidate.eq_ignore_colref_pos(expr, self.arena) + || candidate + .unpack_alias_ref() + .eq_ignore_colref_pos(expr.unpack_alias_ref(), self.arena) }) .map(|position| { let output_expr = &self.project_exprs[position]; - ScalarExpression::column_expr(output_expr.output_column(), position) + ScalarExpression::column_expr(output_expr.output_column_ref(self.arena), position) }) } } -impl<'a> VisitorMut<'a> for ProjectionOutputBinder<'_> { +impl<'a> VisitorMut<'a> for ProjectionOutputBinder<'_, '_> { fn visit(&mut self, expr: &'a mut ScalarExpression) -> Result<(), DatabaseError> { if let Some(output_ref) = self.output_ref(expr) { *expr = output_ref; @@ -160,8 +164,538 @@ impl<'a> VisitorMut<'a> for ProjectionOutputBinder<'_> { } } +pub(crate) struct BindPlanStart<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) binder: &'s mut Binder<'a, 'b, T, A>, + pub(crate) arena: &'s mut crate::planner::PlanArena<'arena>, +} + +pub struct BindPlanFrom<'s, 'a, 'b, 'arena, T, A, M = ()> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) binder: &'s mut Binder<'a, 'b, T, A>, + pub(crate) arena: &'s mut crate::planner::PlanArena<'arena>, + pub(crate) plan: LogicalPlan, + pub(crate) _marker: std::marker::PhantomData, +} + +pub struct BindPlanSelectList<'s, 'a, 'b, 'arena, T, A, M = ()> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) binder: &'s mut Binder<'a, 'b, T, A>, + pub(crate) arena: &'s mut crate::planner::PlanArena<'arena>, + pub(super) plan: LogicalPlan, + pub(super) select_list: Vec, + pub(crate) _marker: std::marker::PhantomData, +} + +pub(crate) struct BindPlanFiltered<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(super) binder: &'s mut Binder<'a, 'b, T, A>, + pub(super) arena: &'s mut crate::planner::PlanArena<'arena>, + pub(super) plan: LogicalPlan, + pub(super) select_list: Vec, +} + +pub(crate) struct BindPlanAggregated<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + binder: &'s mut Binder<'a, 'b, T, A>, + arena: &'s mut crate::planner::PlanArena<'arena>, + plan: LogicalPlan, + select_list: Vec, + having: Option, + orderby: Option>, +} + +pub(crate) struct BindPlanHaving<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + binder: &'s mut Binder<'a, 'b, T, A>, + arena: &'s mut crate::planner::PlanArena<'arena>, + plan: LogicalPlan, + select_list: Vec, + orderby: Option>, +} + +pub(crate) struct BindPlanDistinct<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + binder: &'s mut Binder<'a, 'b, T, A>, + arena: &'s mut crate::planner::PlanArena<'arena>, + plan: LogicalPlan, + select_list: Vec, + orderby: Option>, +} + +pub(crate) struct BindPlanSorted<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + binder: &'s mut Binder<'a, 'b, T, A>, + arena: &'s mut crate::planner::PlanArena<'arena>, + plan: LogicalPlan, + select_list: Vec, +} + +pub(crate) struct BindPlanProjected<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + plan: LogicalPlan, + _marker: std::marker::PhantomData<(&'s (), &'a (), &'b (), &'arena (), T, A)>, +} + +pub(crate) struct BindPlanComplete { + plan: LogicalPlan, +} + +pub(crate) struct TableAliasInput { + pub(crate) name: TableName, + pub(crate) columns: Vec, +} + +pub(crate) enum JoinConstraintInput { + On(ScalarExpression), + Using(Vec), + Natural, + None, +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A, M> BindPlanFrom<'s, 'a, 'b, 'arena, T, A, M> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + #[cfg(feature = "orm")] + pub(crate) fn typed(self) -> BindPlanFrom<'s, 'a, 'b, 'arena, T, A, N> { + BindPlanFrom { + binder: self.binder, + arena: self.arena, + plan: self.plan, + _marker: std::marker::PhantomData, + } + } + + #[cfg(feature = "orm")] + pub(crate) fn filter_expr( + mut self, + predicate: ScalarExpression, + ) -> Result { + self.plan = self + .binder + .bind_where_expr(self.plan, predicate, self.arena)?; + Ok(self) + } + + #[cfg(feature = "orm")] + pub(crate) fn join_plan( + mut self, + right_plan: LogicalPlan, + right_context: BinderContext<'a, T>, + join_type: JoinType, + constraint: JoinConstraintInput, + ) -> Result { + self.binder.extend(right_context); + self.plan = self + .binder + .bind_join_plans(self.plan, right_plan, join_type, constraint, self.arena)?; + Ok(self) + } + + pub(crate) fn select_list( + self, + select_list: Vec, + ) -> BindPlanSelectList<'s, 'a, 'b, 'arena, T, A, M> { + BindPlanSelectList { + binder: self.binder, + arena: self.arena, + plan: self.plan, + select_list, + _marker: std::marker::PhantomData, + } + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A, M> BindPlanSelectList<'s, 'a, 'b, 'arena, T, A, M> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + #[cfg(feature = "orm")] + pub(crate) fn set_select_list(mut self, select_list: Vec) -> Self { + self.select_list = select_list; + self + } + + #[cfg(feature = "orm")] + pub(crate) fn group_by_expr(self, expr: ScalarExpression) -> Result { + let sorted = self + .filter_expr(None)? + .aggregate( + vec![expr], + None, + None::>, + |_binder, _arena, order| Ok(order), + )? + .having()? + .distinct(false)? + .order_by()?; + Ok(BindPlanSelectList { + binder: sorted.binder, + arena: sorted.arena, + plan: sorted.plan, + select_list: sorted.select_list, + _marker: std::marker::PhantomData, + }) + } + + #[cfg(feature = "orm")] + pub(crate) fn aggregate_without_group(self) -> Result { + let sorted = self + .filter_expr(None)? + .aggregate( + Vec::new(), + None, + None::>, + |_binder, _arena, order| Ok(order), + )? + .having()? + .distinct(false)? + .order_by()?; + Ok(BindPlanSelectList { + binder: sorted.binder, + arena: sorted.arena, + plan: sorted.plan, + select_list: sorted.select_list, + _marker: std::marker::PhantomData, + }) + } + + #[cfg(feature = "orm")] + pub(crate) fn having_expr(mut self, expr: ScalarExpression) -> Result { + self.plan = self.binder.bind_having(self.plan, expr, self.arena)?; + Ok(self) + } + + #[cfg(feature = "orm")] + pub(crate) fn sort_field(mut self, field: SortField) -> Result { + self.plan = self.binder.bind_sort(self.plan, vec![field], self.arena)?; + Ok(self) + } + + #[cfg(feature = "orm")] + pub fn distinct(mut self) -> Result { + let distinct_outputs = self.select_list.clone(); + self.binder.bind_distinct_output_exprs( + &distinct_outputs, + self.select_list.iter_mut(), + self.arena, + )?; + self.plan = self.binder.bind_distinct(self.plan, distinct_outputs)?; + Ok(self) + } + + #[cfg(feature = "orm")] + pub fn limit(mut self, limit: usize) -> Result { + self.plan = self + .binder + .bind_limit_values(self.plan, None, Some(limit))?; + Ok(self) + } + + #[cfg(feature = "orm")] + pub fn offset(mut self, offset: usize) -> Result { + self.plan = self + .binder + .bind_limit_values(self.plan, Some(offset), None)?; + Ok(self) + } + + #[cfg(feature = "orm")] + pub fn finish(self) -> Result { + if self.select_list.iter().any(ScalarExpression::has_agg_call) { + return self.aggregate_without_group()?.finish(); + } + self.binder + .bind_project(self.plan, self.select_list, self.arena) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanStart<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + #[allow(clippy::wrong_self_convention)] + pub(crate) fn from_plan( + self, + plan: LogicalPlan, + ) -> Result, DatabaseError> { + Ok(BindPlanFrom { + binder: self.binder, + arena: self.arena, + plan, + _marker: std::marker::PhantomData, + }) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A, M> BindPlanSelectList<'s, 'a, 'b, 'arena, T, A, M> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn filter_expr( + mut self, + predicate: Option, + ) -> Result, DatabaseError> { + if let Some(predicate) = predicate { + self.plan = self + .binder + .bind_where_expr(self.plan, predicate, self.arena)?; + } + + Ok(BindPlanFiltered { + binder: self.binder, + arena: self.arena, + plan: self.plan, + select_list: self.select_list, + }) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanFiltered<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn aggregate( + mut self, + group_by: Vec, + having: Option, + orderby: Option>, + mut bind_sort_field: impl FnMut( + &mut Binder<'a, 'b, T, A>, + &mut crate::planner::PlanArena<'arena>, + O, + ) -> Result, + ) -> Result, DatabaseError> { + self.binder + .extract_select_join(&mut self.select_list, self.arena); + self.binder + .extract_select_aggregate(&mut self.select_list)?; + + if !group_by.is_empty() { + self.binder + .extract_group_by_aggregate_exprs(&mut self.select_list, group_by)?; + } + + let mut having_orderby = (None, None); + if having.is_some() || orderby.is_some() { + having_orderby = self.binder.extract_having_orderby_aggregate_exprs( + having, + orderby, + |binder, orderby| bind_sort_field(binder, self.arena, orderby), + )?; + } + if !self.binder.context.agg_calls.is_empty() + || !self.binder.context.group_by_exprs.is_empty() + { + let agg_calls = std::mem::take(&mut self.binder.context.agg_calls); + let group_by_exprs = std::mem::take(&mut self.binder.context.group_by_exprs); + let output_exprs = self + .select_list + .iter_mut() + .chain(having_orderby.0.iter_mut()) + .chain( + having_orderby + .1 + .iter_mut() + .flat_map(|fields| fields.iter_mut().map(|field| &mut field.expr)), + ); + self.binder.bind_aggregate_output_exprs_with_outputs( + &agg_calls, + &group_by_exprs, + output_exprs, + self.arena, + )?; + self.plan = self + .binder + .bind_aggregate(self.plan, agg_calls, group_by_exprs)?; + } + + Ok(BindPlanAggregated { + binder: self.binder, + arena: self.arena, + plan: self.plan, + select_list: self.select_list, + having: having_orderby.0, + orderby: having_orderby.1, + }) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanAggregated<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn having( + mut self, + ) -> Result, DatabaseError> { + if let Some(having) = self.having { + self.plan = self.binder.bind_having(self.plan, having, self.arena)?; + } + + Ok(BindPlanHaving { + binder: self.binder, + arena: self.arena, + plan: self.plan, + select_list: self.select_list, + orderby: self.orderby, + }) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanHaving<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn distinct( + mut self, + distinct: bool, + ) -> Result, DatabaseError> { + if distinct { + let distinct_outputs = self.select_list.clone(); + self.binder.bind_distinct_output_exprs( + &distinct_outputs, + self.select_list.iter_mut(), + self.arena, + )?; + if let Some(orderby) = self.orderby.as_mut() { + self.binder + .bind_distinct_orderby_exprs(&distinct_outputs, orderby, self.arena)?; + } + self.plan = self.binder.bind_distinct(self.plan, distinct_outputs)?; + } + + Ok(BindPlanDistinct { + binder: self.binder, + arena: self.arena, + plan: self.plan, + select_list: self.select_list, + orderby: self.orderby, + }) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanDistinct<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn order_by( + mut self, + ) -> Result, DatabaseError> { + if let Some(orderby) = self.orderby { + self.plan = self.binder.bind_sort(self.plan, orderby, self.arena)?; + } + + Ok(BindPlanSorted { + binder: self.binder, + arena: self.arena, + plan: self.plan, + select_list: self.select_list, + }) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanSorted<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn project( + mut self, + ) -> Result, DatabaseError> { + if !self.select_list.is_empty() { + self.plan = self + .binder + .bind_project(self.plan, self.select_list, self.arena)?; + } + + Ok(BindPlanProjected { + plan: self.plan, + _marker: std::marker::PhantomData, + }) + } +} + +impl<'s, 'a: 'b, 'b, 'arena, T, A> BindPlanProjected<'s, 'a, 'b, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub(crate) fn insert_into( + mut self, + table_name: Option, + ) -> Result { + if let Some(table_name) = table_name { + self.plan = LogicalPlan::new( + Operator::Insert(InsertOperator { + table_name, + is_overwrite: false, + is_mapping_by_name: true, + }), + Childrens::Only(Box::new(self.plan)), + ) + } + + Ok(BindPlanComplete { plan: self.plan }) + } +} + +impl BindPlanComplete { + pub(crate) fn finish(self) -> LogicalPlan { + self.plan + } +} + impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> { - fn is_temp_alias_projection(exprs: &[ScalarExpression]) -> bool { + pub(crate) fn build_plan<'s, 'arena>( + &'s mut self, + arena: &'s mut crate::planner::PlanArena<'arena>, + ) -> BindPlanStart<'s, 'a, 'b, 'arena, T, A> { + BindPlanStart { + binder: self, + arena, + } + } + + fn is_temp_alias_projection( + exprs: &[ScalarExpression], + arena: &crate::planner::PlanArena, + ) -> bool { !exprs.is_empty() && exprs.iter().all(|expr| { matches!( @@ -173,7 +707,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' alias_expr.unpack_alias_ref(), ScalarExpression::ColumnRef { column, .. } if matches!( - &column.summary().relation, + &arena.column(*column).summary().relation, crate::catalog::ColumnRelation::Table { is_temp: true, .. } ) ) @@ -181,7 +715,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }) } - fn is_joined_values_source(join_type: Option, source: &Source<'a>) -> bool { + pub(crate) fn is_joined_values_source( + join_type: Option, + source: &Source<'a>, + arena: &crate::planner::PlanArena, + ) -> bool { join_type.is_some() && matches!( source, @@ -189,42 +727,25 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if !schema_ref.is_empty() && schema_ref.iter().all(|column| { matches!( - &column.summary().relation, + &arena.column(*column).summary().relation, ColumnRelation::Table { is_temp: true, .. } - ) && column.id() == Some(ColumnId::default()) + ) && arena.column(*column).id() == Some(ColumnId::default()) }) ) } - fn bind_project_output_exprs<'c>( - project_exprs: &[ScalarExpression], - exprs: impl IntoIterator, - ) -> Result<(), DatabaseError> { - let mut binder = ProjectionOutputBinder::new(project_exprs); - for expr in exprs { - binder.visit(expr)?; - } - Ok(()) - } - - pub(crate) fn resolve_source_columns_in_scope( - context: &BinderContext<'a, T>, - table_schema_buf: &mut std::collections::HashMap>, + pub(crate) fn resolve_source_columns_in_scope<'context>( + context: &'context BinderContext<'a, T>, table_name: &str, - ) -> Result<(SchemaRef, usize), DatabaseError> { + ) -> Result<(&'context Source<'a>, usize), DatabaseError> { let mut position_offset = 0; for bound_source in &context.bind_table { - let schema_buf = table_schema_buf - .entry(bound_source.table_name.clone()) - .or_default(); - let schema_ref = bound_source.source.schema_ref(schema_buf); - if bound_source.matches_name(table_name) { - return Ok((schema_ref, position_offset)); + return Ok((&bound_source.source, position_offset)); } - position_offset += schema_ref.len(); + position_offset += bound_source.source.schema_len(); } Err(DatabaseError::invalid_table(table_name)) @@ -252,6 +773,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' join_condition: &mut JoinCondition, left_len: usize, right_schema: &Schema, + arena: &crate::planner::PlanArena, ) -> Result<(), DatabaseError> { let JoinCondition::On { filter, .. } = join_condition else { return Ok(()); @@ -261,6 +783,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' RightSidePositionGlobalizer { right_schema, left_len, + arena, } .visit(expr)?; } @@ -271,19 +794,22 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' fn localize_appended_right_outputs<'expr>( exprs: impl Iterator, appended_outputs: &[AppendedRightOutput], + arena: &crate::planner::PlanArena, ) -> Result<(), DatabaseError> { - struct AppendedRightOutputBinder<'a> { + struct AppendedRightOutputBinder<'a, 'p> { appended_outputs: &'a [AppendedRightOutput], + arena: &'a crate::planner::PlanArena<'p>, } - impl VisitorMut<'_> for AppendedRightOutputBinder<'_> { + impl VisitorMut<'_> for AppendedRightOutputBinder<'_, '_> { fn visit_column_ref( &mut self, column: &mut ColumnRef, position: &mut usize, ) -> Result<(), DatabaseError> { if let Some(output) = self.appended_outputs.iter().find(|output| { - *position == output.child_position && column.same_column(&output.column) + *position == output.child_position + && self.arena.same_column(*column, output.column) }) { *position = output.output_position; } @@ -291,7 +817,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } } - let mut binder = AppendedRightOutputBinder { appended_outputs }; + let mut binder = AppendedRightOutputBinder { + appended_outputs, + arena, + }; for expr in exprs { binder.visit(expr)?; } @@ -303,10 +832,12 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' expr: &mut ScalarExpression, left_schema: &Schema, right_schema: &Schema, + arena: &crate::planner::PlanArena, ) -> Result<(), DatabaseError> { SplitScopePositionRebinder { left_schema, right_schema, + arena, } .visit(expr) } @@ -317,9 +848,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' join_ty: JoinType, predicates: impl IntoIterator, rebind_positions: bool, + arena: &mut crate::planner::PlanArena, ) -> Result { - let left_schema = children.output_schema().clone(); - let right_schema = plan.output_schema().clone(); + let left_schema = children.output_schema(arena); + let right_schema = plan.output_schema(arena); let mut on_keys = Vec::new(); let mut filter = Vec::new(); @@ -327,16 +859,18 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if rebind_positions { Self::rebind_split_scope_positions( &mut predicate, - left_schema.as_ref(), - right_schema.as_ref(), + left_schema, + right_schema, + arena, )?; } Self::extract_join_keys( predicate, &mut on_keys, &mut filter, - left_schema.as_ref(), - right_schema.as_ref(), + left_schema, + right_schema, + arena, )?; } @@ -347,7 +881,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Self::globalize_join_filter_from_split_scope( &mut join_condition, left_schema.len(), - right_schema.as_ref(), + right_schema, + arena, )?; Ok(LJoinOperator::build( @@ -358,275 +893,42 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' )) } - pub(crate) fn bind_query(&mut self, query: &Query) -> Result { - let origin_step = self.context.step_now(); - - if let Some(_with) = &query.with { - // TODO support with clause. - } - - let order_by_exprs = if let Some(order_by) = &query.order_by { - match &order_by.kind { - OrderByKind::Expressions(exprs) => Some(exprs.as_slice()), - OrderByKind::All(_) => { - return Err(DatabaseError::UnsupportedStmt( - "ORDER BY ALL is not supported".to_string(), - )) - } - } - } else { - None - }; - let is_plain_select = matches!(query.body.borrow(), SetExpr::Select(_)); - let mut plan = match query.body.borrow() { - SetExpr::Select(select) => self.bind_select(select, order_by_exprs), - SetExpr::Query(query) => self.bind_query(query), - SetExpr::SetOperation { - op, - set_quantifier, - left, - right, - } => self.bind_set_operation(op, set_quantifier, left, right), - SetExpr::Values(values) => self.bind_temp_values(&values.rows), - expr => { - return Err(DatabaseError::UnsupportedStmt(format!( - "query body: {expr:?}" - ))) - } - }?; - - if !is_plain_select { - if let Some(order_by_exprs) = order_by_exprs { - plan = self.bind_top_level_orderby(plan, order_by_exprs)?; - } - } - - if let Some(limit_clause) = query.limit_clause.clone() { - plan = self.bind_limit(plan, limit_clause)?; - } - - self.context.step(origin_step); - Ok(plan) - } - - fn bind_top_level_orderby( - &mut self, - mut plan: LogicalPlan, - orderbys: &[OrderByExpr], - ) -> Result { - let saved_aliases = self.context.expr_aliases.clone(); - for (position, column) in plan.output_schema().iter().enumerate() { - self.context.add_alias( - None, - column.name().to_string(), - ScalarExpression::column_expr(column.clone(), position), - ); - } - - let sort_fields = self.extract_having_orderby_aggregate(&None, orderbys)?.1; - self.context.expr_aliases = saved_aliases; - - Ok(match sort_fields { - Some(sort_fields) => self.bind_sort(plan, sort_fields)?, - None => plan, - }) - } - - pub(crate) fn bind_select( - &mut self, - select: &Select, - orderby: Option<&[OrderByExpr]>, - ) -> Result { - let mut plan = if select.from.is_empty() { - LogicalPlan::new(Operator::Dummy, Childrens::None) - } else { - let mut plan = self.bind_table_ref(&select.from[0])?; - - if select.from.len() > 1 { - for from in select.from[1..].iter() { - plan = LJoinOperator::build( - plan, - self.bind_table_ref(from)?, - JoinCondition::None, - JoinType::Cross, - ) - } - } - plan - }; - let select_bind_step = self.context.step_now(); - self.context.step(QueryBindStep::Project); - let mut select_list = self.normalize_select_item(&select.projection)?; - self.context.step(select_bind_step); - - if let Some(predicate) = &select.selection { - plan = self.bind_where(plan, predicate)?; - } - self.extract_select_join(&mut select_list); - self.extract_select_aggregate(&mut select_list)?; - - match &select.group_by { - GroupByExpr::Expressions(group_by_exprs, modifiers) => { - if !modifiers.is_empty() { - return Err(DatabaseError::UnsupportedStmt( - "GROUP BY modifiers are not supported".to_string(), - )); - } - if !group_by_exprs.is_empty() { - self.extract_group_by_aggregate(&mut select_list, group_by_exprs)?; - } - } - GroupByExpr::All(_) => { - return Err(DatabaseError::UnsupportedStmt( - "GROUP BY ALL is not supported".to_string(), - )) - } - } - - let mut having_orderby = (None, None); - - if select.having.is_some() || orderby.is_some() { - having_orderby = - self.extract_having_orderby_aggregate(&select.having, orderby.unwrap_or(&[]))?; - } - - if !self.context.agg_calls.is_empty() || !self.context.group_by_exprs.is_empty() { - plan = self.bind_aggregate( - plan, - self.context.agg_calls.clone(), - self.context.group_by_exprs.clone(), - )?; - self.bind_aggregate_output_exprs(select_list.iter_mut())?; - if let Some(orderby) = having_orderby.1.as_mut() { - self.bind_aggregate_output_exprs(orderby.iter_mut().map(|field| &mut field.expr))?; - } - } - - if let Some(having) = having_orderby.0 { - plan = self.bind_having(plan, having)?; - } - - if let Some(Distinct::Distinct) = select.distinct { - plan = self.bind_distinct(plan, select_list.clone())?; - let distinct_outputs = select_list.clone(); - self.bind_distinct_output_exprs(&distinct_outputs, select_list.iter_mut())?; - if let Some(orderby) = having_orderby.1.as_mut() { - self.bind_distinct_orderby_exprs(&distinct_outputs, orderby)?; - } - } - - if let Some(orderby) = having_orderby.1 { - plan = self.bind_sort(plan, orderby)?; - } - - if !select_list.is_empty() { - plan = self.bind_project(plan, select_list)?; - } - - if let Some(SelectInto { name, .. }) = &select.into { - plan = LogicalPlan::new( - Operator::Insert(InsertOperator { - table_name: lower_case_name(name)?.into(), - is_overwrite: false, - is_mapping_by_name: true, - }), - Childrens::Only(Box::new(plan)), - ) - } - - Ok(plan) - } - - /// FIXME: temp values need to register BindContext.bind_table - fn bind_temp_values(&mut self, expr_rows: &[Vec]) -> Result { - let values_len = expr_rows[0].len(); - - let mut inferred_types: Vec> = vec![None; values_len]; - let mut rows = Vec::with_capacity(expr_rows.len()); - - for expr_row in expr_rows.iter() { - if expr_row.len() != values_len { - return Err(DatabaseError::ValuesLenMismatch(expr_row.len(), values_len)); - } - - let mut row = Vec::with_capacity(values_len); - - for (col_index, expr) in expr_row.iter().enumerate() { - let mut expression = self.bind_expr(expr)?; - ConstantCalculator.visit(&mut expression)?; - - if let ScalarExpression::Constant(value) = expression { - let value_type = value.logical_type(); - - inferred_types[col_index] = match &inferred_types[col_index] { - Some(existing) => { - Some(LogicalType::max_logical_type(existing, &value_type)?.into_owned()) - } - None => Some(value_type), - }; - - row.push(value); - } else { - return Err(DatabaseError::ColumnsEmpty); - } - } - - rows.push(row); - } - - let value_name = self.context.temp_table(); - let column_refs: Vec = inferred_types - .into_iter() - .enumerate() - .map(|(col_index, typ)| { - let typ = typ.ok_or(DatabaseError::InvalidType)?; - let mut column_ref = ColumnCatalog::new( - col_index.to_string(), - false, - ColumnDesc::new(typ, None, false, None)?, - ); - column_ref.set_ref_table(value_name.clone(), ColumnId::default(), true); - Ok(ColumnRef::from(column_ref)) - }) - .collect::>()?; - - Ok(self.bind_values(rows, Arc::new(column_refs))) - } - fn bind_set_cast( - &self, + &mut self, mut left_plan: LogicalPlan, mut right_plan: LogicalPlan, + arena: &mut crate::planner::PlanArena, ) -> Result<(LogicalPlan, LogicalPlan), DatabaseError> { let mut left_cast = vec![]; let mut right_cast = vec![]; - let left_schema = left_plan.output_schema(); - let right_schema = right_plan.output_schema(); + let left_schema = left_plan.output_schema(arena); + let right_schema = right_plan.output_schema(arena); for (position, (left_schema, right_schema)) in left_schema.iter().zip(right_schema.iter()).enumerate() { + let left_column = arena.column(*left_schema); + let right_column = arena.column(*right_schema); let cast_type = - LogicalType::max_logical_type(left_schema.datatype(), right_schema.datatype())?; - if cast_type.as_ref() != left_schema.datatype() { + LogicalType::max_logical_type(left_column.datatype(), right_column.datatype())?; + if cast_type.as_ref() != left_column.datatype() { left_cast.push(ScalarExpression::type_cast( - ScalarExpression::column_expr(left_schema.clone(), position), + ScalarExpression::column_expr(*left_schema, position), cast_type.clone(), + arena, )?); } else { - left_cast.push(ScalarExpression::column_expr(left_schema.clone(), position)); + left_cast.push(ScalarExpression::column_expr(*left_schema, position)); } - if cast_type.as_ref() != right_schema.datatype() { + if cast_type.as_ref() != right_column.datatype() { right_cast.push(ScalarExpression::type_cast( - ScalarExpression::column_expr(right_schema.clone(), position), + ScalarExpression::column_expr(*right_schema, position), cast_type.clone(), + arena, )?); } else { - right_cast.push(ScalarExpression::column_expr( - right_schema.clone(), - position, - )); + right_cast.push(ScalarExpression::column_expr(*right_schema, position)); } } @@ -647,61 +949,16 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Ok((left_plan, right_plan)) } - pub(crate) fn bind_set_operation( + pub(crate) fn bind_set_operation_plans( &mut self, - op: &SetOperator, - set_quantifier: &SetQuantifier, - left: &SetExpr, - right: &SetExpr, + op: SetOperatorKind, + is_all: bool, + mut left_plan: LogicalPlan, + mut right_plan: LogicalPlan, + arena: &mut crate::planner::PlanArena, ) -> Result { - let is_all = match set_quantifier { - SetQuantifier::All => true, - SetQuantifier::Distinct | SetQuantifier::None => false, - SetQuantifier::ByName | SetQuantifier::AllByName | SetQuantifier::DistinctByName => { - return Err(DatabaseError::UnsupportedStmt( - "set quantifier BY NAME is not supported".to_string(), - )) - } - }; - let BinderContext { - table_cache, - view_cache, - transaction, - scala_functions, - table_functions, - temp_table_id, - .. - } = &self.context; - let mut left_binder = Binder::new( - BinderContext::new( - table_cache, - view_cache, - *transaction, - scala_functions, - table_functions, - temp_table_id.clone(), - ), - self.args, - Some(self), - ); - let mut right_binder = Binder::new( - BinderContext::new( - table_cache, - view_cache, - *transaction, - scala_functions, - table_functions, - temp_table_id.clone(), - ), - self.args, - Some(self), - ); - - let mut left_plan = left_binder.bind_set_expr(left)?; - let mut right_plan = right_binder.bind_set_expr(right)?; - - let mut left_schema = left_plan.output_schema(); - let mut right_schema = right_plan.output_schema(); + let mut left_schema = left_plan.output_schema(arena); + let mut right_schema = right_plan.output_schema(arena); let left_len = left_schema.len(); @@ -715,15 +972,15 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if !left_schema .iter() .zip(right_schema.iter()) - .all(|(left, right)| left.datatype() == right.datatype()) + .all(|(left, right)| arena.column(*left).datatype() == arena.column(*right).datatype()) { - (left_plan, right_plan) = self.bind_set_cast(left_plan, right_plan)?; - left_schema = left_plan.output_schema(); - right_schema = right_plan.output_schema(); + (left_plan, right_plan) = self.bind_set_cast(left_plan, right_plan, arena)?; + left_schema = left_plan.output_schema(arena); + right_schema = right_plan.output_schema(arena); } match op { - SetOperator::Union => { + SetOperatorKind::Union => { if is_all { Ok(UnionOperator::build( left_schema.clone(), @@ -756,10 +1013,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' )?) } } - SetOperator::Except | SetOperator::Intersect => { + SetOperatorKind::Except | SetOperatorKind::Intersect => { let kind = match op { - SetOperator::Except => SetMembershipKind::Except, - SetOperator::Intersect => SetMembershipKind::Intersect, + SetOperatorKind::Except => SetMembershipKind::Except, + SetOperatorKind::Intersect => SetMembershipKind::Intersect, _ => unreachable!(), }; @@ -779,8 +1036,8 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' left_plan = self.bind_distinct(left_plan, left_distinct_exprs)?; right_plan = self.bind_distinct(right_plan, right_distinct_exprs)?; - left_schema = left_plan.output_schema(); - right_schema = right_plan.output_schema(); + left_schema = left_plan.output_schema(arena); + right_schema = right_plan.output_schema(arena); } Ok(SetMembershipOperator::build( @@ -791,194 +1048,50 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' right_plan, )) } - set_operator => Err(DatabaseError::UnsupportedStmt(format!( - "set operator: {set_operator:?}" - ))), } } - pub(crate) fn bind_table_ref( - &mut self, - from: &TableWithJoins, - ) -> Result { - self.context.step(QueryBindStep::From); - - let TableWithJoins { relation, joins } = from; - let mut plan = self.bind_single_table_ref(relation, None)?; - - for join in joins { - plan = self.bind_join(plan, join)?; - } - Ok(plan) - } - - fn bind_single_table_ref( - &mut self, - table: &TableFactor, - joint_type: Option, - ) -> Result { - let plan = match table { - TableFactor::Table { name, alias, .. } => { - let table_name = lower_case_name(name)?; - - self._bind_single_table_ref(joint_type, &table_name, alias.as_ref())? - } - TableFactor::Derived { - subquery, alias, .. - } => { - let BinderContext { - table_cache, - view_cache, - transaction, - scala_functions, - table_functions, - temp_table_id, - .. - } = &self.context; - let mut binder = Binder::new( - BinderContext::new( - table_cache, - view_cache, - *transaction, - scala_functions, - table_functions, - temp_table_id.clone(), - ), - self.args, - Some(self), - ); - let mut plan = binder.bind_query(subquery)?; - - if let Some(TableAlias { - name, - columns: alias_column, - .. - }) = alias - { - let source_name = self.context.temp_table(); - let table_alias: TableName = lower_ident(name).into(); - - plan = self.bind_alias( - plan, - alias_column, - table_alias.clone(), - source_name.clone(), - )?; - self.context.add_bound_source( - table_alias.clone(), - Some(table_alias), - joint_type, - Source::Schema(plan.output_schema().clone()), - ); - } else { - let passthrough_source = { - let output_schema = plan.output_schema().clone(); - let mut names = output_schema - .iter() - .filter_map(|column| column.table_name().cloned()); - let first = names.next(); - if first.is_some() && names.all(|name| Some(name) == first) { - first - } else { - None - } - }; - let needs_virtual_source = passthrough_source.is_none(); - let source_name = - passthrough_source.unwrap_or_else(|| self.context.temp_table()); - - if needs_virtual_source { - plan = self.bind_schema_source(plan, source_name.clone()); - } - self.context.add_bound_source( - source_name.clone(), - None, - joint_type, - Source::Schema(plan.output_schema().clone()), - ); - } - plan - } - TableFactor::TableFunction { expr, alias } => { - if let ScalarExpression::TableFunction(function) = self.bind_expr(expr)? { - let mut table_alias = None; - let table_name: TableName = function.summary().name.clone(); - let mut plan = FunctionScanOperator::build(function); - - if let Some(TableAlias { - name, - columns: alias_column, - .. - }) = alias - { - table_alias = Some(lower_ident(name).into()); - - plan = self.bind_alias( - plan, - alias_column, - table_alias.clone().unwrap(), - table_name.clone(), - )?; - } - - let source = Source::Schema(plan.output_schema().clone()); - self.context - .add_bound_source(table_name, table_alias, joint_type, source); - plan - } else { - unreachable!() - } - } - table => return Err(DatabaseError::UnsupportedStmt(format!("{table:#?}"))), - }; - - Ok(plan) - } - pub(crate) fn bind_alias( &mut self, mut plan: LogicalPlan, - alias_column: &[TableAliasColumnDef], + alias_column: &[String], table_alias: TableName, table_name: TableName, + arena: &mut crate::planner::PlanArena, ) -> Result { - let input_schema = plan.output_schema(); - if !alias_column.is_empty() && alias_column.len() != input_schema.len() { + let input_schema = plan.output_schema(arena); + let input_schema_len = input_schema.len(); + if !alias_column.is_empty() && alias_column.len() != input_schema_len { return Err(DatabaseError::MisMatch("alias", "columns")); } - let aliases_with_columns = if alias_column.is_empty() { - input_schema - .iter() - .cloned() - .map(|column| (column.name().to_string(), column)) - .collect_vec() - } else { - alias_column - .iter() - .map(|column| lower_ident(&column.name).into_owned()) - .zip(input_schema.iter().cloned()) - .collect_vec() - }; - let mut alias_exprs = Vec::with_capacity(aliases_with_columns.len()); + let mut alias_exprs = Vec::with_capacity(input_schema_len); - for (alias, column) in aliases_with_columns { - let mut alias_column = ColumnCatalog::clone(&column); + for (position, column) in input_schema.iter().copied().enumerate() { + let alias = if alias_column.is_empty() { + arena.column(column).name().to_string() + } else { + alias_column[position].clone() + }; + let (mut alias_column, column_id, is_temp) = { + let source_column = arena.column(column); + ( + source_column.clone(), + source_column.id().unwrap_or_default(), + matches!( + &source_column.summary().relation, + ColumnRelation::Table { is_temp: true, .. } + ), + ) + }; alias_column.set_name(alias.clone()); - let is_temp = matches!( - &column.summary().relation, - ColumnRelation::Table { is_temp: true, .. } - ); - alias_column.set_ref_table( - table_alias.clone(), - column.id().unwrap_or(ColumnId::new()), - is_temp, - ); + alias_column.set_ref_table(table_alias.clone(), column_id, is_temp); + let alias_column = arena.alloc_column(alias_column); let alias_column_expr = ScalarExpression::Alias { - expr: Box::new(ScalarExpression::column_expr(column, alias_exprs.len())), + expr: Box::new(ScalarExpression::column_expr(column, position)), alias: AliasType::Expr(Box::new(ScalarExpression::column_expr( - ColumnRef::from(alias_column), - alias_exprs.len(), + alias_column, + position, ))), }; self.context.add_alias( @@ -989,25 +1102,36 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' alias_exprs.push(alias_column_expr); } self.context.add_table_alias(table_alias, table_name); - self.bind_project(plan, alias_exprs) + self.bind_project(plan, alias_exprs, arena) } - fn bind_schema_source(&mut self, mut plan: LogicalPlan, source_name: TableName) -> LogicalPlan { - let input_schema = plan.output_schema(); - let mut source_exprs = Vec::with_capacity(input_schema.len()); - - for (position, column) in input_schema.iter().cloned().enumerate() { - let mut source_column = ColumnCatalog::clone(&column); - source_column.set_ref_table( - source_name.clone(), - column.id().unwrap_or(ColumnId::new()), - true, - ); + fn bind_schema_source( + &mut self, + mut plan: LogicalPlan, + source_name: TableName, + arena: &mut crate::planner::PlanArena, + ) -> LogicalPlan { + let input_schema = plan.output_schema(arena); + let input_schema_len = input_schema.len(); + let mut source_exprs = Vec::with_capacity(input_schema_len); + + for (position, column) in input_schema.iter().copied().enumerate() { + let source_column = { + let column_catalog = arena.column(column); + let mut source_column = column_catalog.clone(); + source_column.set_ref_table( + source_name.clone(), + column_catalog.id().unwrap_or_default(), + true, + ); + source_column + }; + let source_column = arena.alloc_column(source_column); source_exprs.push(ScalarExpression::Alias { expr: Box::new(ScalarExpression::column_expr(column, position)), alias: AliasType::Expr(Box::new(ScalarExpression::column_expr( - ColumnRef::from(source_column), + source_column, position, ))), }); @@ -1016,20 +1140,14 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Self::build_project_plan(plan, source_exprs) } - pub(crate) fn _bind_single_table_ref( + pub(crate) fn bind_base_table_ref( &mut self, join_type: Option, - table: &str, - alias: Option<&TableAlias>, + table_name: TableName, + alias: Option, + arena: &mut crate::planner::PlanArena, ) -> Result { - let table_name = table.into(); - let mut table_alias: Option = None; - let mut alias_idents = None; - - if let Some(TableAlias { name, columns, .. }) = alias { - table_alias = Some(lower_ident(name).into()); - alias_idents = Some(columns); - } + let table_alias = alias.as_ref().map(|alias| alias.name.clone()); let with_pk = self.is_scan_with_pk(&table_name); let source = self @@ -1037,7 +1155,9 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' .source_and_bind(table_name.clone(), table_alias.as_ref(), join_type, false)? .ok_or(DatabaseError::SourceNotFound)?; let mut plan = match source { - Source::Table(table) => TableScanOperator::build(table_name.clone(), table, with_pk)?, + Source::Table(table) => { + TableScanOperator::build(table_name.clone(), table, with_pk, arena)? + } Source::View(view) => LogicalPlan::clone(&view.plan), Source::Schema(_) => { return Err(DatabaseError::UnsupportedStmt( @@ -1046,116 +1166,139 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } }; - if let (Some(idents), Some(alias_name)) = (alias_idents, table_alias) { - plan = self.bind_alias(plan, idents, alias_name.clone(), table_name.clone())?; + if let Some(alias) = alias { + plan = self.bind_alias( + plan, + &alias.columns, + alias.name.clone(), + table_name.clone(), + arena, + )?; + let output_schema = plan.output_schema(arena).clone(); self.context.add_bound_source( table_name, - Some(alias_name), + Some(alias.name), join_type, - Source::Schema(plan.output_schema().clone()), + Source::Schema(output_schema), ); } Ok(plan) } - /// Normalize select item. - /// - /// - Qualified name, e.g. `SELECT t.a FROM t` - /// - Qualified name with wildcard, e.g. `SELECT t.* FROM t,t1` - /// - Scalar expression or aggregate expression, e.g. `SELECT COUNT(*) + 1 AS count FROM t` - /// - fn normalize_select_item( + pub(crate) fn bind_derived_source( &mut self, - items: &[SelectItem], - ) -> Result, DatabaseError> { - let mut select_items = vec![]; - - for item in items.iter() { - match item { - SelectItem::UnnamedExpr(expr) => select_items.push(self.bind_expr(expr)?), - SelectItem::ExprWithAlias { expr, alias } => { - let expr = self.bind_expr(expr)?; - let alias_name = lower_ident(alias).into_owned(); - - self.context - .add_alias(None, alias_name.clone(), expr.clone()); - - select_items.push(ScalarExpression::Alias { - expr: Box::new(expr), - alias: AliasType::Name(alias_name), - }); - } - SelectItem::Wildcard(_) => { - for visible_name in self - .context - .bind_table - .iter() - .filter(|bound_source| { - !Self::is_joined_values_source( - bound_source.join_type, - &bound_source.source, - ) - }) - .map(|bound_source| bound_source.visible_name()) - .unique() - .cloned() - { - Self::bind_table_column_refs( - &self.context, - &mut self.table_schema_buf, - &mut select_items, - visible_name, - false, - )?; - } - } - SelectItem::QualifiedWildcard(table_name, _) => { - let table_name: TableName = match table_name { - SelectItemQualifiedWildcardKind::ObjectName(name) => { - lower_case_name(name)?.into() - } - SelectItemQualifiedWildcardKind::Expr(expr) => { - return Err(DatabaseError::UnsupportedStmt(format!( - "qualified wildcard expr: {expr}" - ))) - } - }; - Self::bind_table_column_refs( - &self.context, - &mut self.table_schema_buf, - &mut select_items, - table_name, - true, - )?; + mut plan: LogicalPlan, + alias: Option, + joint_type: Option, + arena: &mut crate::planner::PlanArena, + ) -> Result { + if let Some(alias) = alias { + let source_name = arena.temp_table(); + + plan = self.bind_alias( + plan, + &alias.columns, + alias.name.clone(), + source_name.clone(), + arena, + )?; + let output_schema = plan.output_schema(arena).clone(); + self.context.add_bound_source( + alias.name.clone(), + Some(alias.name), + joint_type, + Source::Schema(output_schema), + ); + } else { + let passthrough_source = { + let output_schema = plan.output_schema(arena); + let mut names = output_schema + .iter() + .filter_map(|column| arena.column(*column).table_name().cloned()); + let first = names.next(); + if first.is_some() && names.all(|name| Some(name) == first) { + first + } else { + None } }; + let needs_virtual_source = passthrough_source.is_none(); + let source_name = passthrough_source.unwrap_or_else(|| arena.temp_table()); + + if needs_virtual_source { + plan = self.bind_schema_source(plan, source_name.clone(), arena); + } + let output_schema = plan.output_schema(arena).clone(); + self.context.add_bound_source( + source_name.clone(), + None, + joint_type, + Source::Schema(output_schema), + ); } - Ok(select_items) + Ok(plan) } + pub(crate) fn bind_table_function_source( + &mut self, + expr: ScalarExpression, + alias: Option, + joint_type: Option, + arena: &mut crate::planner::PlanArena, + ) -> Result { + let ScalarExpression::TableFunction(function) = expr else { + return Err(DatabaseError::UnsupportedStmt( + "table function source must be a table function expression".to_string(), + )); + }; + + let mut table_alias = None; + let table_name: TableName = function.summary().name.clone(); + let mut plan = FunctionScanOperator::build(function); + + if let Some(alias) = alias { + table_alias = Some(alias.name.clone()); + + plan = self.bind_alias(plan, &alias.columns, alias.name, table_name.clone(), arena)?; + } + + let source = Source::Schema(plan.output_schema(arena).clone()); + self.context + .add_bound_source(table_name, table_alias, joint_type, source); + Ok(plan) + } + + /// Normalize select item. + /// + /// - Qualified name, e.g. `SELECT t.a FROM t` + /// - Qualified name with wildcard, e.g. `SELECT t.* FROM t,t1` + /// - Scalar expression or aggregate expression, e.g. `SELECT COUNT(*) + 1 AS count FROM t` + /// #[allow(unused_assignments)] - fn bind_table_column_refs( + pub(crate) fn bind_table_column_refs( context: &BinderContext<'a, T>, - table_schema_buf: &mut std::collections::HashMap>, + arena: &mut crate::planner::PlanArena, exprs: &mut Vec, table_name: TableName, is_qualified_wildcard: bool, ) -> Result<(), DatabaseError> { + let (source, position_offset) = + Self::resolve_source_columns_in_scope(context, table_name.as_ref())?; + let fn_not_on_using = |column: &ColumnRef| { + let column_catalog = arena.column(*column); if context.using.is_empty() { - return Some(&table_name) == column.table_name(); + return Some(&table_name) == column_catalog.table_name(); } is_qualified_wildcard - || Some(&table_name) == column.table_name() + || Some(&table_name) == column_catalog.table_name() && !context .using .values() - .any(|using_column| using_column.hides_column(column)) + .any(|using_column| using_column.hides_column(column, arena)) }; - let (schema_ref, position_offset) = - Self::resolve_source_columns_in_scope(context, table_schema_buf, table_name.as_ref())?; let mut pushed_alias_columns = false; for alias_column in context @@ -1166,10 +1309,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' .then_some(alias_column.as_str()) }) { - let Some((position, column)) = schema_ref + let Some((position, column)) = source + .schema() .iter() .enumerate() - .find(|(_, column)| column.name() == alias_column) + .find(|(_, column)| arena.column(**column).name() == alias_column) else { continue; }; @@ -1177,7 +1321,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' continue; } exprs.push(ScalarExpression::column_expr( - column.clone(), + *column, position_offset + position, )); pushed_alias_columns = true; @@ -1187,100 +1331,48 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' return Ok(()); } - for (position, column) in schema_ref.iter().enumerate() { + for (position, column) in source.schema().iter().enumerate() { if !fn_not_on_using(column) { continue; } exprs.push(ScalarExpression::column_expr( - column.clone(), + *column, position_offset + position, )); } Ok(()) } - fn bind_join( + pub(crate) fn bind_join_plans( &mut self, mut left: LogicalPlan, - join: &Join, + mut right: LogicalPlan, + join_type: JoinType, + constraint: JoinConstraintInput, + arena: &mut crate::planner::PlanArena, ) -> Result { - let Join { - relation, - join_operator, - .. - } = join; - - let (join_type, joint_condition) = match join_operator { - JoinOperator::Join(constraint) - | JoinOperator::Inner(constraint) - | JoinOperator::StraightJoin(constraint) => (JoinType::Inner, Some(constraint)), - JoinOperator::Left(constraint) | JoinOperator::LeftOuter(constraint) => { - (JoinType::LeftOuter, Some(constraint)) - } - JoinOperator::Right(constraint) | JoinOperator::RightOuter(constraint) => { - (JoinType::RightOuter, Some(constraint)) - } - JoinOperator::FullOuter(constraint) => (JoinType::Full, Some(constraint)), - JoinOperator::CrossJoin(constraint) => (JoinType::Cross, Some(constraint)), - JoinOperator::Semi(_) - | JoinOperator::LeftSemi(_) - | JoinOperator::Anti(_) - | JoinOperator::LeftAnti(_) - | JoinOperator::RightSemi(_) - | JoinOperator::RightAnti(_) - | JoinOperator::CrossApply - | JoinOperator::OuterApply - | JoinOperator::AsOf { .. } => { - return Err(DatabaseError::UnsupportedStmt(format!("{join_operator:?}"))) - } - }; - let BinderContext { - table_cache, - view_cache, - transaction, - scala_functions, - table_functions, - temp_table_id, - .. - } = &self.context; - let mut binder = Binder::new( - BinderContext::new( - table_cache, - view_cache, - *transaction, - scala_functions, - table_functions, - temp_table_id.clone(), - ), - self.args, - Some(self), - ); - let mut right = binder.bind_single_table_ref(relation, Some(join_type))?; - self.extend(binder.context); - - let mut on = match joint_condition { - Some(constraint) => self.bind_join_constraint( - join_type, - left.output_schema(), - right.output_schema(), - constraint, - )?, - None => JoinCondition::None, - }; - Self::localize_join_condition_from_join_scope(&mut on, left.output_schema().len())?; + let left_len = left.output_schema(arena).len(); + right.output_schema(arena); + let mut on = self.bind_join_constraint( + join_type, + constraint, + left.output_schema(arena), + right.output_schema(arena), + arena, + )?; + Self::localize_join_condition_from_join_scope(&mut on, left_len)?; Ok(LJoinOperator::build(left, right, on, join_type)) } - pub(crate) fn bind_where( + pub(crate) fn bind_where_expr( &mut self, mut children: LogicalPlan, - predicate: &Expr, + mut predicate: ScalarExpression, + arena: &mut crate::planner::PlanArena, ) -> Result { self.context.step(QueryBindStep::Where); - let mut predicate = self.bind_expr(predicate)?; - if let Some(sub_queries) = self.context.sub_queries_at_now() { let mut uses_mark_apply = None; for sub_query in sub_queries { @@ -1297,14 +1389,16 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' )); } uses_mark_apply = Some(true); + let left_schema = children.output_schema(arena).clone(); let (plan, predicates) = Self::prepare_mark_apply( &mut predicate, &output_column, - children.output_schema(), + left_schema.as_ref(), plan, correlated, false, Vec::new(), + arena, )?; children = MarkApplyOperator::build_exists( children, @@ -1332,14 +1426,16 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' quantified_predicate = Self::rewrite_correlated_quantified_predicate(quantified_predicate); } + let left_schema = children.output_schema(arena).clone(); let (plan, predicates) = Self::prepare_mark_apply( &mut predicate, &output_column, - children.output_schema(), + left_schema.as_ref(), plan, correlated, true, vec![quantified_predicate], + arena, )?; children = MarkApplyOperator::build_quantified( children, @@ -1369,13 +1465,14 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' JoinType::Inner, std::iter::once(predicate.clone()), true, + arena, )?; } } } if matches!(uses_mark_apply, Some(true)) { let passthrough_exprs = children - .output_schema() + .output_schema(arena) .iter() .cloned() .enumerate() @@ -1397,8 +1494,9 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' fn ensure_mark_apply_right_outputs( plan: &mut LogicalPlan, predicates: &[ScalarExpression], + arena: &mut crate::planner::PlanArena, ) -> Vec { - let output_schema = plan.output_schema().clone(); + let output_schema = plan.output_schema(arena).clone(); let output_len = output_schema.len(); if let LogicalPlan { operator: Operator::Project(op), @@ -1409,38 +1507,37 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' let Childrens::Only(child) = childrens.as_mut() else { return Vec::new(); }; + let child_schema = child.output_schema(arena); let mut appended_outputs = Vec::new(); op.exprs.extend( - child - .output_schema() + child_schema .iter() .enumerate() .filter(|(_, column)| { !output_schema.contains(column) && predicates.iter().any(|expr| { - expr.any_referenced_column(true, |candidate| { - candidate.same_column(column) + expr.any_referenced_column(arena, |arena, candidate| { + arena.same_column(*candidate, **column) }) }) }) .map(|(position, column)| { appended_outputs.push(AppendedRightOutput { - column: column.clone(), + column: *column, child_position: position, output_position: output_len + appended_outputs.len(), }); - ScalarExpression::column_expr(column.clone(), position) + ScalarExpression::column_expr(*column, position) }), ); - if !appended_outputs.is_empty() { - plan.reset_output_schema_cache(); - } + plan.reset_output_schema_cache(); return appended_outputs; } Vec::new() } + #[allow(clippy::too_many_arguments)] fn prepare_mark_apply( predicate: &mut ScalarExpression, output_column: &ColumnRef, @@ -1449,16 +1546,18 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' correlated: bool, preserve_projection: bool, mut apply_predicates: Vec, + arena: &mut crate::planner::PlanArena, ) -> Result<(LogicalPlan, Vec), DatabaseError> { let left_len = left_schema.len(); MarkerPositionGlobalizer { output_column, left_len, + arena, } .visit(predicate)?; let (mut plan, correlated_filters) = if correlated { - Self::prepare_correlated_subquery_plan(plan, left_schema, preserve_projection)? + Self::prepare_correlated_subquery_plan(plan, left_schema, preserve_projection, arena)? } else { (plan, Vec::new()) }; @@ -1466,19 +1565,21 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if correlated { let appended_right_outputs = - Self::ensure_mark_apply_right_outputs(&mut plan, &apply_predicates); + Self::ensure_mark_apply_right_outputs(&mut plan, &apply_predicates, arena); if !appended_right_outputs.is_empty() { Self::localize_appended_right_outputs( apply_predicates.iter_mut(), &appended_right_outputs, + arena, )?; } } - let right_schema = plan.output_schema().clone(); + let right_schema = plan.output_schema(arena); for expr in apply_predicates.iter_mut() { RightSidePositionGlobalizer { - right_schema: right_schema.as_ref(), + right_schema, left_len, + arena, } .visit(expr)?; } @@ -1513,25 +1614,42 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } } - fn plan_has_correlated_refs(plan: &LogicalPlan, left_schema: &Schema) -> bool { - let contains = |column: &ColumnRef| left_schema.contains(column); - - if plan.operator.any_referenced_column(true, contains) { + fn plan_has_correlated_refs( + plan: &LogicalPlan, + left_schema: &Schema, + arena: &mut crate::planner::PlanArena, + ) -> bool { + if !plan + .operator + .visit_referenced_columns(arena, &mut |arena, column| { + !left_schema + .iter() + .any(|left| arena.same_column(*left, *column)) + }) + { return true; } match plan.childrens.as_ref() { - Childrens::Only(child) => Self::plan_has_correlated_refs(child, left_schema), + Childrens::Only(child) => Self::plan_has_correlated_refs(child, left_schema, arena), Childrens::Twins { left, right } => { - Self::plan_has_correlated_refs(left, left_schema) - || Self::plan_has_correlated_refs(right, left_schema) + Self::plan_has_correlated_refs(left, left_schema, arena) + || Self::plan_has_correlated_refs(right, left_schema, arena) } Childrens::None => false, } } - fn expr_has_correlated_refs(expr: &ScalarExpression, left_schema: &Schema) -> bool { - expr.any_referenced_column(true, |column| left_schema.contains(column)) + fn expr_has_correlated_refs( + expr: &ScalarExpression, + left_schema: &Schema, + arena: &mut crate::planner::PlanArena, + ) -> bool { + expr.any_referenced_column(arena, |arena, column| { + left_schema + .iter() + .any(|left| arena.same_column(*left, *column)) + }) } fn split_conjuncts(expr: ScalarExpression, exprs: &mut Vec) { @@ -1565,11 +1683,12 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' plan: LogicalPlan, left_schema: &Schema, preserve_projection: bool, + arena: &mut crate::planner::PlanArena, ) -> Result<(LogicalPlan, Vec), DatabaseError> { match plan.childrens.as_ref() { Childrens::Only(_) => {} Childrens::Twins { .. } => { - if Self::plan_has_correlated_refs(&plan, left_schema) { + if Self::plan_has_correlated_refs(&plan, left_schema, arena) { return Err(DatabaseError::UnsupportedStmt( "correlated EXISTS/NOT EXISTS does not support set or join subqueries" .to_string(), @@ -1590,12 +1709,13 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' child, left_schema, preserve_projection, + arena, )?; let mut local_filters = Vec::new(); let mut predicates = Vec::new(); Self::split_conjuncts(op.predicate, &mut predicates); for predicate in predicates { - if Self::expr_has_correlated_refs(&predicate, left_schema) { + if Self::expr_has_correlated_refs(&predicate, left_schema, arena) { correlated_filters.push(predicate); } else { local_filters.push(predicate); @@ -1618,12 +1738,16 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' child, left_schema, preserve_projection, + arena, )?; - if !preserve_projection || Self::is_temp_alias_projection(&op.exprs) { + if !preserve_projection || Self::is_temp_alias_projection(&op.exprs, arena) { Ok((child, correlated_filters)) } else { - Self::bind_project_output_exprs(&op.exprs, correlated_filters.iter_mut())?; + let mut binder = ProjectionOutputBinder::new(&op.exprs, arena); + for expr in correlated_filters.iter_mut() { + binder.visit(expr)?; + } Ok(( LogicalPlan::new(Operator::Project(op), Childrens::Only(Box::new(child))), correlated_filters, @@ -1648,9 +1772,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' childrens.pop_only(), left_schema, preserve_projection, + arena, ), plan => { - if Self::plan_has_correlated_refs(&plan, left_schema) { + if Self::plan_has_correlated_refs(&plan, left_schema, arena) { Err(DatabaseError::UnsupportedStmt( "correlated EXISTS/NOT EXISTS only supports filter-based subqueries" .to_string(), @@ -1666,11 +1791,12 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' &mut self, children: LogicalPlan, mut having: ScalarExpression, + arena: &mut crate::planner::PlanArena, ) -> Result { self.context.step(QueryBindStep::Having); self.validate_having_orderby(&having)?; - self.bind_aggregate_output_exprs(std::iter::once(&mut having))?; + self.bind_aggregate_output_exprs(std::iter::once(&mut having), arena)?; Ok(FilterOperator::build(having, children, true)) } @@ -1688,6 +1814,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' &mut self, mut children: LogicalPlan, mut select_list: Vec, + arena: &mut crate::planner::PlanArena, ) -> Result { self.context.step(QueryBindStep::Project); @@ -1709,12 +1836,13 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' )); } - let left_len = children.output_schema().len(); - let right_schema = plan.output_schema().clone(); + let left_len = children.output_schema(arena).len(); + let right_schema = plan.output_schema(arena); for expr in select_list.iter_mut() { RightSidePositionGlobalizer { - right_schema: right_schema.as_ref(), + right_schema, left_len, + arena, } .visit(expr)?; } @@ -1726,10 +1854,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Ok(Self::build_project_plan(children, select_list)) } - fn bind_sort( + pub(crate) fn bind_sort( &mut self, children: LogicalPlan, sort_fields: Vec, + _arena: &mut crate::planner::PlanArena, ) -> Result { self.context.step(QueryBindStep::Sort); @@ -1742,63 +1871,22 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' )) } - fn bind_non_negative_limit_value(&mut self, expr: &Expr) -> Result { - let bound_expr = self.bind_expr(expr)?; - match bound_expr { - ScalarExpression::Constant(dv) => match &dv { - DataValue::Int32(v) if *v >= 0 => Ok(*v as usize), - DataValue::Int64(v) if *v >= 0 => Ok(*v as usize), - _ => Err(DatabaseError::InvalidType), - }, - _ => Err(attach_span_if_absent( - DatabaseError::invalid_column("invalid limit expression.".to_owned()), - expr, - )), - } - } - - fn bind_limit( + pub(crate) fn bind_limit_values( &mut self, children: LogicalPlan, - limit: LimitClause, + offset_value: Option, + limit_value: Option, ) -> Result { self.context.step(QueryBindStep::Limit); - let mut limit_value = None; - let mut offset_value = None; - match limit { - LimitClause::LimitOffset { - limit: limit_expr, - offset: offset_expr, - limit_by, - } => { - if !limit_by.is_empty() { - return Err(DatabaseError::UnsupportedStmt( - "LIMIT BY is not supported".to_string(), - )); - } - - if let Some(limit_ast) = limit_expr.as_ref() { - limit_value = Some(self.bind_non_negative_limit_value(limit_ast)?); - } - - if let Some(offset_ast) = offset_expr.as_ref() { - offset_value = Some(self.bind_non_negative_limit_value(&offset_ast.value)?); - } - } - LimitClause::OffsetCommaLimit { - offset: offset_expr, - limit: limit_expr, - } => { - limit_value = Some(self.bind_non_negative_limit_value(&limit_expr)?); - offset_value = Some(self.bind_non_negative_limit_value(&offset_expr)?); - } - } - Ok(LimitOperator::build(offset_value, limit_value, children)) } - pub fn extract_select_join(&mut self, select_items: &mut [ScalarExpression]) { + pub fn extract_select_join( + &mut self, + select_items: &mut [ScalarExpression], + arena: &mut crate::planner::PlanArena, + ) { if self.context.bind_table.len() < 2 { return; } @@ -1829,16 +1917,14 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' if let ScalarExpression::ColumnRef { column, .. } = column { let _ = table_force_nullable .iter() - .find(|(table_name, source, _)| { - let schema_buf = self - .table_schema_buf - .entry((*table_name).clone()) - .or_default(); - - source.column(column.name(), schema_buf).is_some() + .find(|(table_name, _source, _)| { + arena + .column(*column) + .table_name() + .is_some_and(|column_table| column_table == *table_name) }) .map(|(_, _, nullable)| { - if let Some(new_column) = column.nullable_for_join(*nullable) { + if let Some(new_column) = arena.nullable_for_join(*column, *nullable) { *column = new_column; } }); @@ -1846,20 +1932,20 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } } - fn bind_join_constraint<'c>( + fn bind_join_constraint( &mut self, join_type: JoinType, - left_schema: &'c SchemaRef, - right_schema: &'c SchemaRef, - constraint: &JoinConstraint, + constraint: JoinConstraintInput, + left_schema: &Schema, + right_schema: &Schema, + arena: &mut crate::planner::PlanArena, ) -> Result { match constraint { - JoinConstraint::On(expr) => { + JoinConstraintInput::On(expr) => { // left and right columns that match equi-join pattern let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; // expression that didn't match equi-join pattern let mut filter = vec![]; - let expr = self.bind_expr(expr)?; Self::extract_join_keys( expr, @@ -1867,6 +1953,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' &mut filter, left_schema, right_schema, + arena, )?; // combine multiple filter exprs into one BinaryExpr @@ -1884,32 +1971,31 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' filter: join_filter, }) } - JoinConstraint::Using(idents) => { + JoinConstraintInput::Using(names) => { fn find_column<'a>( schema: &'a Schema, name: &'a str, + arena: &crate::planner::PlanArena, ) -> Option<(usize, &'a ColumnRef)> { schema .iter() .enumerate() - .find(|(_, column)| column.name() == name) + .find(|(_, column)| arena.column(**column).name() == name) } let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = Vec::new(); - for ident in idents { - let name = lower_case_name(ident)?; + for name in names { let (Some((left_position, left_column)), Some((right_position, right_column))) = ( - find_column(left_schema, &name), - find_column(right_schema, &name), + find_column(left_schema, &name, arena), + find_column(right_schema, &name, arena), ) else { - return Err(attach_span_if_absent( - DatabaseError::invalid_column("not found column".to_string()), - ident, + return Err(DatabaseError::invalid_column( + "not found column".to_string(), )); }; self.context.add_using( - name.clone().into_owned(), + name.clone(), join_type, left_column, left_position, @@ -1917,9 +2003,9 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' left_schema.len() + right_position, )?; on_keys.push(( - ScalarExpression::column_expr(left_column.clone(), left_position), + ScalarExpression::column_expr(*left_column, left_position), ScalarExpression::column_expr( - right_column.clone(), + *right_column, left_schema.len() + right_position, ), )); @@ -1929,10 +2015,13 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' filter: None, }) } - JoinConstraint::None => Ok(JoinCondition::None), - JoinConstraint::Natural => { - let fn_names = |schema: &'c Schema| -> HashSet<&'c str> { - schema.iter().map(|column| column.name()).collect() + JoinConstraintInput::None => Ok(JoinCondition::None), + JoinConstraintInput::Natural => { + let fn_names = |schema: &Schema| -> HashSet { + schema + .iter() + .map(|column| arena.column(*column).name().to_string()) + .collect() }; let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = Vec::new(); @@ -1944,21 +2033,20 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' left_schema .iter() .enumerate() - .find(|(_, column)| column.name() == *name), + .find(|(_, column)| arena.column(**column).name() == name), right_schema .iter() .enumerate() - .find(|(_, column)| column.name() == *name), + .find(|(_, column)| arena.column(**column).name() == name), ) { - let left_expr = - ScalarExpression::column_expr(left_column.clone(), left_position); + let left_expr = ScalarExpression::column_expr(*left_column, left_position); let right_expr = ScalarExpression::column_expr( - right_column.clone(), + *right_column, left_schema.len() + right_position, ); self.context.add_using( - name.to_string(), + name.clone(), join_type, left_column, left_position, @@ -1993,14 +2081,17 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' accum_filter: &mut Vec, left_schema: &Schema, right_schema: &Schema, + arena: &crate::planner::PlanArena, ) -> Result<(), DatabaseError> { - let fn_contains = |schema: &Schema, summary: &ColumnSummary| { - schema.iter().any(|column| summary == column.summary()) + let fn_contains = |schema: &Schema, column: ColumnRef| { + let summary = arena.column(column).summary(); + schema + .iter() + .any(|candidate| arena.column(*candidate).summary() == summary) + }; + let fn_or_contains = |column: ColumnRef| { + fn_contains(left_schema, column) || fn_contains(right_schema, column) }; - let fn_or_contains = - |left_schema: &Schema, right_schema: &Schema, summary: &ColumnSummary| { - fn_contains(left_schema, summary) || fn_contains(right_schema, summary) - }; match expr.unpack_alias() { ScalarExpression::Binary { @@ -2019,17 +2110,13 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' ScalarExpression::ColumnRef { column: r, .. }, ) => { // reorder left and right joins keys to pattern: (left, right) - if fn_contains(left_schema, l.summary()) - && fn_contains(right_schema, r.summary()) - { + if fn_contains(left_schema, *l) && fn_contains(right_schema, *r) { accum.push((*left_expr, *right_expr)); - } else if fn_contains(left_schema, r.summary()) - && fn_contains(right_schema, l.summary()) + } else if fn_contains(left_schema, *r) + && fn_contains(right_schema, *l) { accum.push((*right_expr, *left_expr)); - } else if fn_or_contains(left_schema, right_schema, l.summary()) - || fn_or_contains(left_schema, right_schema, r.summary()) - { + } else if fn_or_contains(*l) || fn_or_contains(*r) { accum_filter.push(ScalarExpression::Binary { left_expr, right_expr, @@ -2041,7 +2128,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } (ScalarExpression::ColumnRef { column, .. }, _) | (_, ScalarExpression::ColumnRef { column, .. }) => { - if fn_or_contains(left_schema, right_schema, column.summary()) { + if fn_or_contains(*column) { accum_filter.push(ScalarExpression::Binary { left_expr, right_expr, @@ -2053,10 +2140,10 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } _other => { // example: baz > 1 - if left_expr.all_referenced_columns(true, |column| { - fn_or_contains(left_schema, right_schema, column.summary()) - }) && right_expr.all_referenced_columns(true, |column| { - fn_or_contains(left_schema, right_schema, column.summary()) + if left_expr.all_referenced_columns(arena, |_, column| { + fn_or_contains(*column) + }) && right_expr.all_referenced_columns(arena, |_, column| { + fn_or_contains(*column) }) { accum_filter.push(ScalarExpression::Binary { left_expr, @@ -2077,6 +2164,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' accum_filter, left_schema, right_schema, + arena, )?; Self::extract_join_keys( *right_expr, @@ -2084,6 +2172,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' accum_filter, left_schema, right_schema, + arena, )?; } BinaryOperator::Or => { @@ -2096,11 +2185,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' }); } _ => { - if left_expr.all_referenced_columns(true, |column| { - fn_or_contains(left_schema, right_schema, column.summary()) - }) && right_expr.all_referenced_columns(true, |column| { - fn_or_contains(left_schema, right_schema, column.summary()) - }) { + if left_expr + .all_referenced_columns(arena, |_, column| fn_or_contains(*column)) + && right_expr + .all_referenced_columns(arena, |_, column| fn_or_contains(*column)) + { accum_filter.push(ScalarExpression::Binary { left_expr, right_expr, @@ -2113,9 +2202,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' } } expr => { - if expr.all_referenced_columns(true, |column| { - fn_or_contains(left_schema, right_schema, column.summary()) - }) { + if expr.all_referenced_columns(arena, |_, column| fn_or_contains(*column)) { // example: baz > 1 accum_filter.push(expr); } @@ -2130,7 +2217,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' mod tests { use super::{ProjectionOutputBinder, RightSidePositionGlobalizer}; use crate::binder::test::build_t1_table; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::errors::DatabaseError; use crate::expression::visitor_mut::VisitorMut; use crate::expression::{AliasType, ScalarExpression}; @@ -2139,12 +2226,12 @@ mod tests { MarkApplyKind, MarkApplyOperator, MarkApplyQuantifier, }; use crate::planner::operator::Operator; - use crate::planner::{Childrens, LogicalPlan}; + use crate::planner::{Childrens, LogicalPlan, PlanArena}; use crate::types::LogicalType; - fn test_column(name: &str, position: usize) -> ScalarExpression { + fn test_column(arena: &mut PlanArena, name: &str, position: usize) -> ScalarExpression { ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::new( + arena.alloc_column(ColumnCatalog::new( name.to_string(), true, ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), @@ -2186,21 +2273,23 @@ mod tests { #[test] fn test_right_side_position_globalizer_only_shifts_right_columns() -> Result<(), DatabaseError> { - let left_column = ColumnRef::from(ColumnCatalog::new( + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); + let left_column = arena.alloc_column(ColumnCatalog::new( "left".to_string(), true, ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), )); - let right_column = ColumnRef::from(ColumnCatalog::new( + let right_column = arena.alloc_column(ColumnCatalog::new( "right".to_string(), true, ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), )); - let right_schema = vec![right_column.clone()]; + let right_schema = vec![right_column]; let mut expr = ScalarExpression::Binary { op: crate::expression::BinaryOperator::Eq, left_expr: Box::new(ScalarExpression::column_expr(left_column, 0)), - right_expr: Box::new(ScalarExpression::column_expr(right_column.clone(), 0)), + right_expr: Box::new(ScalarExpression::column_expr(right_column, 0)), evaluator: None, ty: LogicalType::Boolean, }; @@ -2208,6 +2297,7 @@ mod tests { RightSidePositionGlobalizer { right_schema: &right_schema, left_len: 2, + arena: &arena, } .visit(&mut expr)?; @@ -2240,21 +2330,23 @@ mod tests { #[test] fn test_projection_output_binder_rewrites_to_project_slot() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); let project_output = ScalarExpression::Alias { - expr: Box::new(test_column("c1", 0)), + expr: Box::new(test_column(&mut arena, "c1", 0)), alias: AliasType::Name("v".to_string()), }; let mut expr = ScalarExpression::Alias { - expr: Box::new(test_column("c1", 0)), + expr: Box::new(test_column(&mut arena, "c1", 0)), alias: AliasType::Name("v".to_string()), }; - ProjectionOutputBinder::new(std::slice::from_ref(&project_output)).visit(&mut expr)?; + ProjectionOutputBinder::new(std::slice::from_ref(&project_output), &mut arena) + .visit(&mut expr)?; - assert_eq!( - expr, - ScalarExpression::column_expr(project_output.output_column(), 0) - ); + let expected = + ScalarExpression::column_expr(project_output.output_column_ref(&mut arena), 0); + assert!(expr.eq_ignore_colref_pos(&expected, &arena)); Ok(()) } @@ -2423,7 +2515,9 @@ mod tests { else { panic!("expected join filter") }; - let left_len = left.output_schema_direct().columns().count(); + let mut arena = PlanArena::new(&table_states.table_arena); + let mut left_plan = left.as_ref().clone(); + let left_len = left_plan.output_schema(&mut arena).len(); let mut positions = Vec::new(); collect_column_positions(filter, &mut positions); diff --git a/src/binder/truncate.rs b/src/binder/truncate.rs index fb0ef677..cd81370f 100644 --- a/src/binder/truncate.rs +++ b/src/binder/truncate.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{lower_case_name, Binder}; +use crate::binder::Binder; use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::planner::operator::truncate::TruncateOperator; @@ -20,15 +20,12 @@ use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use sqlparser::ast::ObjectName; impl> Binder<'_, '_, T, A> { pub(crate) fn bind_truncate( &mut self, - name: &ObjectName, + table_name: TableName, ) -> Result { - let table_name: TableName = lower_case_name(name)?.into(); - Ok(LogicalPlan::new( Operator::Truncate(TruncateOperator { table_name }), Childrens::None, diff --git a/src/binder/update.rs b/src/binder/update.rs index 83d62f53..e532d110 100644 --- a/src/binder/update.rs +++ b/src/binder/update.rs @@ -12,162 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{ - attach_span_from_sqlparser_span_if_absent, attach_span_if_absent, lower_case_name, Binder, -}; -use crate::catalog::TableName; +use crate::binder::Binder; +use crate::catalog::{ColumnRef, TableName}; use crate::errors::DatabaseError; -use crate::expression::visitor_mut::VisitorMut; use crate::expression::ScalarExpression; -use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::update::UpdateOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; use crate::storage::Transaction; use crate::types::value::DataValue; -use sqlparser::ast::{ - Assignment, AssignmentTarget, Expr, Ident, ObjectName, TableFactor, TableWithJoins, -}; -use std::borrow::Cow; -use std::slice; - -struct UpdateExprTargetRemapper<'a> { - target_schema: &'a [crate::catalog::ColumnRef], -} - -impl VisitorMut<'_> for UpdateExprTargetRemapper<'_> { - fn visit_column_ref( - &mut self, - column: &mut crate::catalog::ColumnRef, - position: &mut usize, - ) -> Result<(), DatabaseError> { - let Some(target_position) = self - .target_schema - .iter() - .position(|target_column| target_column.same_column(column)) - else { - return Err(DatabaseError::UnsupportedStmt( - "joined UPDATE SET expressions can only reference target table columns".to_string(), - )); - }; - *position = target_position; - Ok(()) - } -} impl> Binder<'_, '_, T, A> { - fn single_ident_from_object_name(name: &ObjectName) -> Result<&Ident, DatabaseError> { - if name.0.len() != 1 { - return Err(attach_span_if_absent( - DatabaseError::invalid_column(name.to_string()), - name, - )); - } - name.0[0].as_ident().ok_or_else(|| { - attach_span_if_absent(DatabaseError::invalid_column(name.to_string()), name) - }) - } - pub(crate) fn bind_update( &mut self, - to: &TableWithJoins, - selection: &Option, - assignments: &[Assignment], + table_name: TableName, + value_exprs: Vec<(ColumnRef, ScalarExpression)>, + input: LogicalPlan, ) -> Result { - // FIXME: Make it better to detect the current BindStep - self.context.allow_default = true; - if let TableFactor::Table { name, .. } = &to.relation { - let is_joined_update = !to.joins.is_empty(); - let table_name: TableName = lower_case_name(name)?.into(); - self.with_pk(table_name.clone()); - - let mut plan = self.bind_table_ref(to)?; - let (target_schema, target_offset) = Self::resolve_source_columns_in_scope( - &self.context, - &mut self.table_schema_buf, - &table_name, - )?; - - if let Some(predicate) = selection { - plan = self.bind_where(plan, predicate)?; - } - let mut value_exprs = Vec::with_capacity(assignments.len()); - - if assignments.is_empty() { - return Err(DatabaseError::ColumnsEmpty); - } - for Assignment { target, value } in assignments { - let expression = self.bind_expr(value)?; - let mut idents = vec![]; - match target { - AssignmentTarget::ColumnName(name) => { - idents.push(Self::single_ident_from_object_name(name)?); - } - AssignmentTarget::Tuple(_) => { - return Err(DatabaseError::UnsupportedStmt( - "UPDATE assignment tuple target is not supported".to_string(), - )) - } - } - - for ident in idents { - match self.bind_column_ref_from_identifiers( - slice::from_ref(ident), - Some(table_name.as_ref()), - )? { - ScalarExpression::ColumnRef { column, .. } => { - let mut expr = if matches!(expression, ScalarExpression::Empty) { - let default_value = column - .default_value()? - .ok_or(DatabaseError::DefaultNotExist)?; - ScalarExpression::Constant(default_value) - } else { - expression.clone() - }; - expr = ScalarExpression::type_cast( - expr, - Cow::Borrowed(column.datatype()), - )?; - if is_joined_update { - UpdateExprTargetRemapper { - target_schema: &target_schema, - } - .visit(&mut expr)?; - } - value_exprs.push((column, expr)); - } - _ => { - return Err(attach_span_from_sqlparser_span_if_absent( - DatabaseError::invalid_column(ident.to_string()), - ident.span, - )) - } - } - } - } - self.context.allow_default = false; - if is_joined_update { - let exprs = target_schema - .iter() - .enumerate() - .map(|(index, column)| { - ScalarExpression::column_expr(column.clone(), target_offset + index) - }) - .collect(); - plan = LogicalPlan::new( - Operator::Project(ProjectOperator { exprs }), - Childrens::Only(Box::new(plan)), - ); - } - Ok(LogicalPlan::new( - Operator::Update(UpdateOperator { - table_name, - value_exprs, - }), - Childrens::Only(Box::new(plan)), - )) - } else { - unreachable!("only table") - } + Ok(LogicalPlan::new( + Operator::Update(UpdateOperator { + table_name, + value_exprs, + }), + Childrens::Only(Box::new(input)), + )) } } diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 356f0f06..139c6115 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -20,25 +20,12 @@ use crate::types::value::DataValue; use crate::types::CharLengthUnits; use crate::types::{ColumnId, LogicalType}; use kite_sql_serde_macros::ReferenceSerialization; +use std::fmt; use std::hash::Hash; -use std::ops::Deref; -use std::sync::Arc; -#[derive(Debug, Clone, Hash, Eq, PartialEq)] -pub struct ColumnRef(pub Arc); - -impl Deref for ColumnRef { - type Target = ColumnCatalog; - - fn deref(&self) -> &Self::Target { - self.0.as_ref() - } -} - -impl From for ColumnRef { - fn from(c: ColumnCatalog) -> Self { - ColumnRef(Arc::new(c)) - } +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +pub struct ColumnRef { + pos: usize, } #[derive(Debug, Clone, Hash, Eq, PartialEq, ReferenceSerialization)] @@ -49,7 +36,7 @@ pub struct ColumnCatalog { in_join: bool, } -#[derive(Debug, Clone, Hash, Eq, PartialEq)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] pub enum ColumnRelation { None, Table { @@ -59,25 +46,25 @@ pub enum ColumnRelation { }, } -#[derive(Debug, Clone, Hash, Eq, PartialEq, ReferenceSerialization)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Ord, PartialOrd, ReferenceSerialization)] pub struct ColumnSummary { pub name: String, pub relation: ColumnRelation, } impl ColumnRef { - pub(crate) fn same_column(&self, other: &ColumnRef) -> bool { - self.summary() == other.summary() + pub(crate) fn new(pos: usize) -> Self { + Self { pos } } - pub(crate) fn nullable_for_join(&self, nullable: bool) -> Option { - if self.nullable == nullable { - return None; - } - let mut temp = ColumnCatalog::clone(self); - temp.nullable = nullable; - temp.in_join = true; - Some(ColumnRef::from(temp)) + pub(crate) fn pos(self) -> usize { + self.pos + } +} + +impl fmt::Display for ColumnRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "#{}", self.pos) } } @@ -94,20 +81,6 @@ impl ColumnCatalog { } } - pub(crate) fn direct_new( - summary: ColumnSummary, - nullable: bool, - column_desc: ColumnDesc, - in_join: bool, - ) -> ColumnCatalog { - ColumnCatalog { - summary, - nullable, - desc: column_desc, - in_join, - } - } - pub(crate) fn new_dummy(column_name: String) -> ColumnCatalog { ColumnCatalog { summary: ColumnSummary { @@ -160,6 +133,13 @@ impl ColumnCatalog { } } + pub(crate) fn is_persistent_table_column(&self) -> bool { + matches!( + self.summary.relation, + ColumnRelation::Table { is_temp: false, .. } + ) + } + pub fn set_name(&mut self, name: String) { self.summary.name = name; } @@ -184,6 +164,10 @@ impl ColumnCatalog { self.nullable = nullable; } + pub(crate) fn set_in_join(&mut self, in_join: bool) { + self.in_join = in_join; + } + pub fn datatype(&self) -> &LogicalType { &self.desc.column_datatype } @@ -207,8 +191,9 @@ impl ColumnCatalog { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { - use super::{ColumnCatalog, ColumnDesc, ColumnRef}; + use super::{ColumnCatalog, ColumnDesc}; use crate::errors::DatabaseError; + use crate::planner::PlanArena; use crate::types::LogicalType; #[test] @@ -223,17 +208,19 @@ mod tests { true, ColumnDesc::new(LogicalType::Bigint, None, false, None)?, ); - let left_ref = ColumnRef::from(left.clone()); - let right_ref = ColumnRef::from(right.clone()); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); + let left_ref = arena.alloc_column(left.clone()); + let right_ref = arena.alloc_column(right.clone()); assert_ne!(left_ref, right_ref); - assert!(left_ref.same_column(&right_ref)); + assert!(arena.same_column(left_ref, right_ref)); left.set_name("c2".to_string()); right.set_name("c3".to_string()); - let left_ref = ColumnRef::from(left); - let right_ref = ColumnRef::from(right); - assert!(!left_ref.same_column(&right_ref)); + let left_ref = arena.alloc_column(left); + let right_ref = arena.alloc_column(right); + assert!(!arena.same_column(left_ref, right_ref)); Ok(()) } } @@ -255,7 +242,9 @@ impl ColumnDesc { default: Option, ) -> Result { if let Some(expr) = &default { - if expr.has_table_ref_column() { + let table_arena = crate::planner::TableArenaCell::default(); + let plan_arena = crate::planner::PlanArena::new(&table_arena); + if expr.has_table_ref_column(&plan_arena) { return Err(DatabaseError::DefaultNotColumnRef); } } diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 7210907b..ab94712b 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -15,8 +15,9 @@ use crate::catalog::{ColumnCatalog, ColumnRef, ColumnRelation}; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; +use crate::planner::{MetaArena, PlanArena}; use crate::types::index::{IndexMeta, IndexMetaRef, IndexType}; -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use crate::types::{ColumnId, LogicalType}; use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; @@ -26,7 +27,6 @@ use std::{slice, vec}; use ulid::Generator; pub type TableName = Arc; -pub type PrimaryKeyIndices = Arc>; #[derive(Debug, Clone, PartialEq)] pub struct TableCatalog { @@ -34,17 +34,17 @@ pub struct TableCatalog { /// Mapping from column names to column ids column_idxs: BTreeMap, columns: BTreeMap, + column_refs: Vec, pub(crate) indexes: Vec, - schema_ref: SchemaRef, primary_keys: Vec<(usize, ColumnRef)>, - primary_key_indices: PrimaryKeyIndices, + primary_key_indices: Vec, primary_key_type: LogicalType, } -pub(crate) struct DmlTableSnapshot { - pub(crate) schema_ref: SchemaRef, - pub(crate) primary_key_indices: PrimaryKeyIndices, +pub(crate) struct DmlTableSnapshot<'a> { + pub(crate) columns: Schema, + pub(crate) primary_key_indices: &'a [usize], pub(crate) columns_len: usize, pub(crate) index_metas: Vec<(IndexMetaRef, Vec)>, } @@ -60,15 +60,19 @@ impl TableCatalog { &self.name } - pub(crate) fn get_unique_index(&self, col_id: &ColumnId) -> Option<&IndexMetaRef> { - self.indexes - .iter() - .find(|meta| matches!(meta.ty, IndexType::Unique) && &meta.column_ids[0] == col_id) + pub(crate) fn get_unique_index( + &self, + col_id: &ColumnId, + arena: &impl MetaArena, + ) -> Option { + self.indexes.iter().copied().find(|meta| { + let meta = arena.index(*meta); + matches!(meta.ty, IndexType::Unique) && &meta.column_ids[0] == col_id + }) } - #[allow(dead_code)] - pub(crate) fn get_column_by_id(&self, id: &ColumnId) -> Option<&ColumnRef> { - self.columns.get(id).map(|i| &self.schema_ref[*i]) + pub(crate) fn get_column_by_id(&self, id: &ColumnId) -> Option { + self.columns.get(id).map(|i| self.column_refs[*i]) } #[cfg(all(test, not(target_arch = "wasm32")))] @@ -76,10 +80,10 @@ impl TableCatalog { self.column_idxs.get(name).map(|(id, _)| id) } - pub(crate) fn get_column_by_name(&self, name: &str) -> Option<&ColumnRef> { + pub(crate) fn get_column_by_name(&self, name: &str) -> Option { self.column_idxs .get(name) - .map(|(_, i)| &self.schema_ref[*i]) + .map(|(_, i)| self.column_refs[*i]) } #[allow(dead_code)] @@ -88,15 +92,15 @@ impl TableCatalog { } pub(crate) fn columns(&self) -> slice::Iter<'_, ColumnRef> { - self.schema_ref.iter() + self.column_refs.iter() } - pub(crate) fn indexes(&self) -> slice::Iter<'_, IndexMetaRef> { - self.indexes.iter() + pub(crate) fn column_ref(&self, index: usize) -> Option { + self.column_refs.get(index).copied() } - pub fn schema_ref(&self) -> &SchemaRef { - &self.schema_ref + pub(crate) fn indexes(&self) -> slice::Iter<'_, IndexMetaRef> { + self.indexes.iter() } pub(crate) fn columns_len(&self) -> usize { @@ -111,19 +115,28 @@ impl TableCatalog { &self.primary_key_type } - pub(crate) fn primary_keys_indices(&self) -> &PrimaryKeyIndices { + #[cfg(feature = "copy")] + pub(crate) fn primary_key_indices(&self) -> &[usize] { &self.primary_key_indices } - pub(crate) fn dml_snapshot(&self) -> Result { + pub(crate) fn dml_snapshot( + &self, + arena: &mut PlanArena, + ) -> Result, DatabaseError> { let index_metas = self .indexes() - .map(|index_meta| Ok((index_meta.clone(), index_meta.column_exprs(self)?))) + .map(|index_meta| { + Ok(( + *index_meta, + arena.index(*index_meta).column_exprs(self, arena)?, + )) + }) .collect::, DatabaseError>>()?; Ok(DmlTableSnapshot { - schema_ref: self.schema_ref.clone(), - primary_key_indices: self.primary_key_indices.clone(), + columns: self.column_refs.clone(), + primary_key_indices: &self.primary_key_indices, columns_len: self.columns_len(), index_metas, }) @@ -134,6 +147,7 @@ impl TableCatalog { &mut self, mut col: ColumnCatalog, generator: &mut Generator, + arena: &mut impl MetaArena, ) -> Result { if self.column_idxs.contains_key(col.name()) { return Err(DatabaseError::DuplicateColumn(col.name().to_string())); @@ -151,12 +165,10 @@ impl TableCatalog { }; self.column_idxs - .insert(col.name().to_string(), (col_id, self.schema_ref.len())); - self.columns.insert(col_id, self.schema_ref.len()); - - let mut schema = Vec::clone(&self.schema_ref); - schema.push(ColumnRef::from(col)); - self.schema_ref = Arc::new(schema); + .insert(col.name().to_string(), (col_id, self.column_refs.len())); + self.columns.insert(col_id, self.column_refs.len()); + let column_ref = arena.alloc_column(col); + self.column_refs.push(column_ref); Ok(col_id) } @@ -166,20 +178,26 @@ impl TableCatalog { name: String, column_ids: Vec, ty: IndexType, - ) -> Result<&IndexMeta, DatabaseError> { + arena: &mut impl MetaArena, + ) -> Result { for index in self.indexes.iter() { - if index.name == name { + if arena.index(*index).name == name { return Err(DatabaseError::DuplicateIndex(name)); } } - let index_id = self.indexes.last().map(|index| index.id + 1).unwrap_or(0); + let index_id = self + .indexes + .last() + .map(|index| arena.index(*index).id + 1) + .unwrap_or(0); let pk_ty = self.primary_key_type.clone(); let mut val_tys = Vec::with_capacity(column_ids.len()); for column_id in column_ids.iter() { let val_ty = self .get_column_by_id(column_id) + .map(|column| arena.column(column)) .ok_or_else(|| DatabaseError::column_not_found(column_id.to_string()))? .datatype() .clone(); @@ -200,13 +218,15 @@ impl TableCatalog { name, ty, }; - self.indexes.push(Arc::new(index)); - Ok(self.indexes.last().unwrap()) + let index_ref = arena.alloc_index(index); + self.indexes.push(index_ref); + Ok(index_ref) } pub fn new( name: TableName, columns: Vec, + arena: &mut impl MetaArena, ) -> Result { if columns.is_empty() { return Err(DatabaseError::ColumnsEmpty); @@ -215,8 +235,8 @@ impl TableCatalog { name, column_idxs: BTreeMap::new(), columns: BTreeMap::new(), + column_refs: vec![], indexes: vec![], - schema_ref: Arc::new(vec![]), primary_keys: vec![], primary_key_indices: Default::default(), primary_key_type: LogicalType::SqlNull, @@ -224,86 +244,121 @@ impl TableCatalog { let mut generator = Generator::new(); for col_catalog in columns.into_iter() { let _ = table_catalog - .add_column(col_catalog, &mut generator) + .add_column(col_catalog, &mut generator, arena) .unwrap(); } let (primary_keys, primary_key_indices) = - Self::build_primary_keys(&table_catalog.schema_ref); + Self::build_primary_keys(&table_catalog.column_refs, arena); - table_catalog.primary_key_type = Self::build_primary_key_type(&primary_keys); + table_catalog.primary_key_type = Self::build_primary_key_type(&primary_keys, arena); table_catalog.primary_keys = primary_keys; table_catalog.primary_key_indices = primary_key_indices; Ok(table_catalog) } - fn build_primary_key_type(primary_keys: &[(usize, ColumnRef)]) -> LogicalType { + fn build_primary_key_type( + primary_keys: &[(usize, ColumnRef)], + arena: &impl MetaArena, + ) -> LogicalType { if primary_keys.len() == 1 { - primary_keys[0].1.datatype().clone() + arena.column(primary_keys[0].1).datatype().clone() } else { LogicalType::Tuple( primary_keys .iter() - .map(|(_, column)| column.datatype().clone()) + .map(|(_, column)| arena.column(*column).datatype().clone()) .collect_vec(), ) } } - pub(crate) fn reload( + pub(crate) fn reload( name: TableName, - column_refs: Vec, - indexes: Vec, - ) -> Result { + column_catalogs: I, + indexes: I2, + arena: &mut impl MetaArena, + ) -> Result + where + I: Iterator, + I2: Iterator, + { + let (lower_bound, _) = column_catalogs.size_hint(); let mut column_idxs = BTreeMap::new(); let mut columns = BTreeMap::new(); + let mut column_refs = Vec::with_capacity(lower_bound); - for (i, column_ref) in column_refs.iter().enumerate() { - let column_id = column_ref.id().ok_or(DatabaseError::invalid_column( + for (i, column_catalog) in column_catalogs.enumerate() { + let column_id = column_catalog.id().ok_or(DatabaseError::invalid_column( "column does not belong to table".to_string(), ))?; - column_idxs.insert(column_ref.name().to_string(), (column_id, i)); + column_idxs.insert(column_catalog.name().to_string(), (column_id, i)); columns.insert(column_id, i); + column_refs.push(arena.alloc_column(column_catalog)); } - let schema_ref = Arc::new(column_refs.clone()); - let (primary_keys, primary_key_indices) = Self::build_primary_keys(&schema_ref); - let primary_key_type = Self::build_primary_key_type(&primary_keys); + let indexes = indexes.map(|index| arena.alloc_index(index)).collect(); + let (primary_keys, primary_key_indices) = Self::build_primary_keys(&column_refs, arena); + let primary_key_type = Self::build_primary_key_type(&primary_keys, arena); Ok(TableCatalog { name, column_idxs, columns, + column_refs, indexes, - schema_ref, primary_keys, primary_key_indices, primary_key_type, }) } + pub(crate) fn transplant_to_table_arena( + &self, + source_arena: &PlanArena, + ) -> Result { + let column_catalogs = self + .columns() + .map(|column| source_arena.column(*column).clone()) + .collect_vec(); + let index_metas = self + .indexes() + .map(|index| source_arena.index(*index).clone()) + .collect_vec(); + + Self::reload( + self.name.clone(), + column_catalogs.into_iter(), + index_metas.into_iter(), + source_arena.table_arena_cell().borrow_mut(), + ) + } + fn build_primary_keys( - schema_ref: &Arc>, - ) -> (Vec<(usize, ColumnRef)>, PrimaryKeyIndices) { + columns: &[ColumnRef], + arena: &impl MetaArena, + ) -> (Vec<(usize, ColumnRef)>, Vec) { let mut primary_keys = Vec::new(); let mut primary_key_indices = Vec::new(); - for (_, (i, column)) in schema_ref + for (i, column) in columns .iter() .enumerate() .filter_map(|(i, column)| { - column + arena + .column(*column) .desc() .primary() - .map(|p_i| (p_i, (i, column.clone()))) + .map(|primary_index| (primary_index, i)) }) - .sorted_by_key(|(p_i, _)| *p_i) + .sorted_by_key(|(primary_index, _)| *primary_index) + .map(|(_, i)| (i, columns[i])) { primary_key_indices.push(i); primary_keys.push((i, column)); } - (primary_keys, Arc::new(primary_key_indices)) + (primary_keys, primary_key_indices) } } @@ -317,9 +372,20 @@ impl TableMeta { mod tests { use super::*; use crate::catalog::ColumnDesc; + use crate::planner::TableArenaCell; use crate::types::LogicalType; use ulid::Generator; + fn build_table_catalog( + name: &str, + columns: Vec, + ) -> (TableArenaCell, TableCatalog) { + let table_arena = TableArenaCell::default(); + let table_catalog = + TableCatalog::new(name.to_string().into(), columns, table_arena.borrow_mut()).unwrap(); + (table_arena, table_catalog) + } + #[test] // | a (Int32) | b (Bool) | // |-----------|----------| @@ -337,7 +403,7 @@ mod tests { ColumnDesc::new(LogicalType::Boolean, None, false, None).unwrap(), ); let col_catalogs = vec![col0, col1]; - let table_catalog = TableCatalog::new("test".to_string().into(), col_catalogs).unwrap(); + let (table_arena, table_catalog) = build_table_catalog("test", col_catalogs); assert!(table_catalog.contains_column("a")); assert!(table_catalog.contains_column("b")); @@ -347,11 +413,15 @@ mod tests { let col_b_id = table_catalog.get_column_id_by_name("b").unwrap(); assert!(col_a_id < col_b_id); - let column_catalog = table_catalog.get_column_by_id(col_a_id).unwrap(); + let column_catalog = table_arena + .borrow() + .column(table_catalog.get_column_by_id(col_a_id).unwrap()); assert_eq!(column_catalog.name(), "a"); assert_eq!(*column_catalog.datatype(), LogicalType::Integer,); - let column_catalog = table_catalog.get_column_by_id(col_b_id).unwrap(); + let column_catalog = table_arena + .borrow() + .column(table_catalog.get_column_by_id(col_b_id).unwrap()); assert_eq!(column_catalog.name(), "b"); assert_eq!(*column_catalog.datatype(), LogicalType::Boolean,); } @@ -359,8 +429,8 @@ mod tests { #[test] fn test_add_column_generates_id_after_existing_columns() { for _ in 0..256 { - let mut table_catalog = TableCatalog::new( - "test".to_string().into(), + let (table_arena, mut table_catalog) = build_table_catalog( + "test", vec![ ColumnCatalog::new( "id".into(), @@ -379,11 +449,10 @@ mod tests { .unwrap(), ), ], - ) - .unwrap(); + ); let max_existing_id = table_catalog .columns() - .filter_map(|column| column.id()) + .filter_map(|column| table_arena.borrow().column(*column).id()) .max() .unwrap(); let mut generator = Generator::new(); @@ -395,6 +464,7 @@ mod tests { ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), ), &mut generator, + table_arena.borrow_mut(), ) .unwrap(); diff --git a/src/catalog/view.rs b/src/catalog/view.rs index 3663a679..84c23251 100644 --- a/src/catalog/view.rs +++ b/src/catalog/view.rs @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::TableName; -use crate::planner::LogicalPlan; +use crate::catalog::{ColumnRef, TableName}; +use crate::planner::{LogicalPlan, MetaArena}; +use crate::types::tuple::Schema; use kite_sql_serde_macros::ReferenceSerialization; use std::fmt; use std::fmt::Formatter; @@ -22,11 +23,25 @@ use std::fmt::Formatter; pub struct View { pub name: TableName, pub plan: Box, + pub schema: Schema, +} + +impl View { + pub(crate) fn visit_column_refs(&self, arena: &mut A, f: &mut F) + where + A: MetaArena, + F: FnMut(&ColumnRef) + ?Sized, + { + for column in &self.schema { + f(column); + } + self.plan.visit_column_refs(arena, f); + } } impl fmt::Display for View { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "View {}: {}", self.name, self.plan.explain(0))?; + write!(f, "View {}", self.name)?; Ok(()) } diff --git a/src/db.rs b/src/db.rs index 395bebd1..b0ff24d4 100644 --- a/src/db.rs +++ b/src/db.rs @@ -12,103 +12,83 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::binder::{command_type, Binder, BinderContext, CommandType}; +#[cfg(feature = "parser")] +pub use crate::binder::{prepare, prepare_all, Statement}; +use crate::binder::{Binder, BinderContext}; +use crate::catalog::TableName; use crate::errors::DatabaseError; -use crate::execution::{build_write, ExecArena, Executor}; +use crate::execution::{build_write, DDLApply, ExecArena, ExecutionContext, Executor}; use crate::expression::function::scala::ScalarFunctionImpl; -use crate::expression::function::table::TableFunctionImpl; +use crate::expression::function::table::{ + ArcTableFunctionImpl, TableFunctionCatalog, TableFunctionImpl, +}; use crate::expression::function::FunctionSummary; use crate::function::char_length::CharLength; +#[cfg(feature = "time")] use crate::function::current_date::CurrentDate; +#[cfg(feature = "time")] use crate::function::current_timestamp::CurrentTimeStamp; use crate::function::lower::Lower; use crate::function::numbers::Numbers; use crate::function::octet_length::OctetLength; use crate::function::upper::Upper; +use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; use crate::optimizer::rule::implementation::ImplementationRuleImpl; use crate::optimizer::rule::normalization::NormalizationRuleImpl; -use crate::parser::parse_sql; use crate::planner::operator::Operator; -use crate::planner::LogicalPlan; +use crate::planner::{LogicalPlan, PlanArena, TableArenaCell}; #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] use crate::storage::lmdb::{LmdbConfig, LmdbStorage}; use crate::storage::memory::MemoryStorage; #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] use crate::storage::rocksdb::{OptimisticRocksStorage, RocksStorage, StorageConfig}; +use crate::storage::table_codec::TableCodec; use crate::storage::{ CheckpointableStorage, StatisticsMetaCache, Storage, TableCache, Transaction, TransactionIsolationLevel, ViewCache, }; -use crate::types::tuple::{SchemaRef, Tuple}; +use crate::types::tuple::{Schema, SchemaView, Tuple}; use crate::types::value::DataValue; -use crate::utils::lru::SharedLruCache; use ahash::HashMap; -use parking_lot::lock_api::{ArcRwLockReadGuard, ArcRwLockWriteGuard}; -use parking_lot::{RawRwLock, RwLock}; -use std::hash::RandomState; +use std::collections::HashSet; use std::marker::PhantomData; use std::mem; use std::path::Path; use std::path::PathBuf; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; pub(crate) type ScalaFunctions = HashMap>; -pub(crate) type TableFunctions = HashMap>; - -/// Parsed SQL statement type used by KiteSQL execution APIs. -/// -/// This is a type alias for `sqlparser::ast::Statement`. In most cases you do -/// not need to construct it manually; use [`prepare`] or [`prepare_all`] to -/// parse SQL text into statements. -pub type Statement = sqlparser::ast::Statement; +pub(crate) type TableFunctions = HashMap; -/// Parses a single SQL statement into a reusable [`Statement`]. -/// -/// This is useful when you want to parse once and execute the same statement -/// multiple times with different parameters. If the input contains multiple -/// statements, only the last one is returned. -/// -/// # Examples -/// -/// ```rust -/// use kite_sql::db::prepare; -/// -/// let statement = prepare("select * from users where id = $1").unwrap(); -/// println!("{statement:?}"); -/// ``` -pub fn prepare>(sql: T) -> Result { - let mut stmts = prepare_all(sql)?; - stmts.pop().ok_or(DatabaseError::EmptyStatement) +pub enum CatalogKind { + Table(crate::catalog::TableName), + View(crate::catalog::TableName), + ScalarFunction(Arc), + TableFunction(Arc), } -/// Parses one or more SQL statements into a vector of [`Statement`] values. -/// -/// Returns [`DatabaseError::EmptyStatement`] when the input is empty or only -/// contains whitespace. -/// -/// # Examples -/// -/// ```rust -/// use kite_sql::db::prepare_all; -/// -/// let statements = prepare_all("select 1; select 2;").unwrap(); -/// assert_eq!(statements.len(), 2); -/// ``` -pub fn prepare_all>(sql: T) -> Result, DatabaseError> { - let stmts = parse_sql(sql)?; - if stmts.is_empty() { - return Err(DatabaseError::EmptyStatement); - } - Ok(stmts) -} +pub(crate) trait BindSource { + type Iter: ResultIter; + type Transaction: Transaction; + + fn execute(self, params: A, build: F) -> Result + where + A: AsRef<[(&'static str, DataValue)]>, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, Self::Transaction, A>, + &mut PlanArena<'_>, + ) -> Result; -#[allow(dead_code)] -pub(crate) enum MetaDataLock { - Read(ArcRwLockReadGuard), - Write(ArcRwLockWriteGuard), + #[cfg(feature = "orm")] + fn explain(self, params: A, build: F) -> Result + where + A: AsRef<[(&'static str, DataValue)]>, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, Self::Transaction, A>, + &mut PlanArena<'_>, + ) -> Result; } /// Builder for creating a [`Database`] instance. @@ -116,10 +96,11 @@ pub(crate) enum MetaDataLock { /// The builder wires together storage, built-in functions and optional runtime /// features before the database is opened. pub struct DataBaseBuilder { - #[cfg_attr(target_arch = "wasm32", allow(dead_code))] + #[cfg(all( + not(target_arch = "wasm32"), + any(feature = "rocksdb", feature = "lmdb") + ))] path: PathBuf, - scala_functions: ScalaFunctions, - table_functions: TableFunctions, histogram_buckets: Option, transaction_isolation: Option, #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] @@ -134,27 +115,30 @@ impl DataBaseBuilder { /// Built-in scalar functions and table functions are registered /// automatically. pub fn path(path: impl Into + Send) -> Self { - let mut builder = DataBaseBuilder { - path: path.into(), - scala_functions: Default::default(), - table_functions: Default::default(), + #[cfg(all( + not(target_arch = "wasm32"), + any(feature = "rocksdb", feature = "lmdb") + ))] + let path = path.into(); + #[cfg(not(all( + not(target_arch = "wasm32"), + any(feature = "rocksdb", feature = "lmdb") + )))] + let _ = path; + + DataBaseBuilder { + #[cfg(all( + not(target_arch = "wasm32"), + any(feature = "rocksdb", feature = "lmdb") + ))] + path, histogram_buckets: None, transaction_isolation: None, #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] storage_config: Default::default(), #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] lmdb_config: Default::default(), - }; - builder = builder.register_scala_function(CharLength::new("char_length".to_lowercase())); - builder = - builder.register_scala_function(CharLength::new("character_length".to_lowercase())); - builder = builder.register_scala_function(CurrentDate::new()); - builder = builder.register_scala_function(CurrentTimeStamp::new()); - builder = builder.register_scala_function(Lower::new()); - builder = builder.register_scala_function(OctetLength::new()); - builder = builder.register_scala_function(Upper::new()); - builder = builder.register_table_function(Numbers::new()); - builder + } } /// Sets the default histogram bucket count used by `ANALYZE`. @@ -169,22 +153,6 @@ impl DataBaseBuilder { self } - /// Registers a user-defined scalar function on the database builder. - pub fn register_scala_function(mut self, function: Arc) -> Self { - let summary = function.summary().clone(); - - self.scala_functions.insert(summary, function); - self - } - - /// Registers a user-defined table function on the database builder. - pub fn register_table_function(mut self, function: Arc) -> Self { - let summary = function.summary().clone(); - - self.table_functions.insert(summary, function); - self - } - /// Enables or disables RocksDB statistics collection. #[cfg(all( not(target_arch = "wasm32"), @@ -241,13 +209,7 @@ impl DataBaseBuilder { /// Builds a database using a custom storage implementation. pub fn build_with_storage(self, storage: T) -> Result, DatabaseError> { - Self::_build::( - storage, - self.scala_functions, - self.table_functions, - self.histogram_buckets, - self.transaction_isolation, - ) + Self::_build::(storage, self.histogram_buckets, self.transaction_isolation) } /// Builds a database for the current target platform. @@ -255,13 +217,7 @@ impl DataBaseBuilder { pub fn build(self) -> Result, DatabaseError> { let storage = MemoryStorage::new(); - Self::_build::( - storage, - self.scala_functions, - self.table_functions, - self.histogram_buckets, - self.transaction_isolation, - ) + Self::_build::(storage, self.histogram_buckets, self.transaction_isolation) } /// Builds a RocksDB-backed database. @@ -269,13 +225,7 @@ impl DataBaseBuilder { pub fn build_rocksdb(self) -> Result, DatabaseError> { let storage = RocksStorage::with_config(self.path, self.storage_config)?; - Self::_build::( - storage, - self.scala_functions, - self.table_functions, - self.histogram_buckets, - self.transaction_isolation, - ) + Self::_build::(storage, self.histogram_buckets, self.transaction_isolation) } /// Builds an in-memory database. @@ -284,13 +234,7 @@ impl DataBaseBuilder { pub fn build_in_memory(self) -> Result, DatabaseError> { let storage = MemoryStorage::new(); - Self::_build::( - storage, - self.scala_functions, - self.table_functions, - self.histogram_buckets, - self.transaction_isolation, - ) + Self::_build::(storage, self.histogram_buckets, self.transaction_isolation) } /// Builds a LMDB-backed database. @@ -298,13 +242,7 @@ impl DataBaseBuilder { pub fn build_lmdb(self) -> Result, DatabaseError> { let storage = LmdbStorage::with_config(self.path, self.lmdb_config)?; - Self::_build::( - storage, - self.scala_functions, - self.table_functions, - self.histogram_buckets, - self.transaction_isolation, - ) + Self::_build::(storage, self.histogram_buckets, self.transaction_isolation) } #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] @@ -315,8 +253,6 @@ impl DataBaseBuilder { Self::_build::( storage, - self.scala_functions, - self.table_functions, self.histogram_buckets, self.transaction_isolation, ) @@ -324,8 +260,6 @@ impl DataBaseBuilder { fn _build( storage: T, - scala_functions: ScalaFunctions, - table_functions: TableFunctions, histogram_buckets: Option, transaction_isolation: Option, ) -> Result, DatabaseError> { @@ -337,24 +271,38 @@ impl DataBaseBuilder { let transaction_isolation = transaction_isolation.unwrap_or_else(|| storage.default_transaction_isolation()); storage.validate_transaction_isolation(transaction_isolation)?; - let meta_cache = SharedLruCache::new(256, 8, RandomState::new())?; - let table_cache = SharedLruCache::new(48, 4, RandomState::new())?; - let view_cache = SharedLruCache::new(12, 4, RandomState::new())?; + let meta_cache = HashMap::default(); + let table_cache = HashMap::default(); + let view_cache = HashMap::default(); + let table_arena = TableArenaCell::default(); + + let mut state = State { + scala_functions: Default::default(), + table_functions: Default::default(), + meta_cache, + table_cache, + view_cache, + table_arena, + optimizer_pipeline: default_optimizer_pipeline(), + histogram_buckets, + _p: Default::default(), + }; + + state.load_scalar_function(CharLength::new("char_length".to_lowercase())); + state.load_scalar_function(CharLength::new("character_length".to_lowercase())); + #[cfg(feature = "time")] + state.load_scalar_function(CurrentDate::new()); + #[cfg(feature = "time")] + state.load_scalar_function(CurrentTimeStamp::new()); + state.load_scalar_function(Lower::new()); + state.load_scalar_function(OctetLength::new()); + state.load_scalar_function(Upper::new()); + state.load_table_function(Numbers::new())?; Ok(Database { storage, transaction_isolation, - mdl: Default::default(), - state: Arc::new(State { - scala_functions, - table_functions, - meta_cache, - table_cache, - view_cache, - optimizer_pipeline: default_optimizer_pipeline(), - histogram_buckets, - _p: Default::default(), - }), + state, }) } } @@ -439,12 +387,14 @@ fn default_optimizer_pipeline() -> HepOptimizerPipeline { ImplementationRuleImpl::Values, // DML ImplementationRuleImpl::Analyze, + #[cfg(feature = "copy")] ImplementationRuleImpl::CopyFromFile, + #[cfg(feature = "copy")] ImplementationRuleImpl::CopyToFile, ImplementationRuleImpl::Delete, ImplementationRuleImpl::Insert, ImplementationRuleImpl::Update, - // DLL + // DDL ImplementationRuleImpl::AddColumn, ImplementationRuleImpl::ChangeColumn, ImplementationRuleImpl::CreateTable, @@ -461,6 +411,7 @@ pub(crate) struct State { meta_cache: StatisticsMetaCache, table_cache: TableCache, view_cache: ViewCache, + table_arena: TableArenaCell, optimizer_pipeline: HepOptimizerPipeline, histogram_buckets: Option, _p: PhantomData, @@ -482,37 +433,92 @@ impl State { pub(crate) fn view_cache(&self) -> &ViewCache { &self.view_cache } + pub(crate) fn table_arena(&self) -> &TableArenaCell { + &self.table_arena + } + + fn load_scalar_function(&mut self, function: Arc) { + self.scala_functions + .insert(function.summary().clone(), function); + } + + fn load_table_function( + &mut self, + function: Arc, + ) -> Result<(), DatabaseError> { + let summary = function.summary().clone(); + let mut schema = Schema::new(); + function.output_schema_into(&summary.name, self.table_arena.borrow_mut(), &mut schema); + self.table_functions.insert( + summary, + TableFunctionCatalog { + schema, + inner: ArcTableFunctionImpl(function), + }, + ); + Ok(()) + } + + fn recycle_table_arena(&mut self) { + let mut live_columns = HashSet::new(); + { + let mut live = |column: &crate::catalog::ColumnRef| { + live_columns.insert(column.pos()); + }; + + for table in self.table_cache.values() { + for column in table.columns() { + live(column); + } + } + + let table_arena = self.table_arena.borrow_mut(); + for view in self.view_cache.values() { + view.visit_column_refs(table_arena, &mut live); + } + + for function in self.table_functions.values() { + for column in &function.schema { + live(column); + } + } + } + self.table_arena + .borrow_mut() + .recycle_unreferenced_positions(live_columns); + } - fn build_plan>( - &self, - stmt: &Statement, + pub(crate) fn build_plan<'a, 'txn, A: AsRef<[(&'static str, DataValue)]>, F>( + &'a self, params: A, - transaction: &::TransactionType<'_>, - ) -> Result { - let mut binder = Binder::new( + transaction: &::TransactionType<'txn>, + build: F, + ) -> Result<(LogicalPlan, PlanArena<'a>), DatabaseError> + where + S: 'txn, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, ::TransactionType<'txn>, A>, + &mut PlanArena<'a>, + ) -> Result, + { + let mut plan_arena = PlanArena::new(self.table_arena()); + let mut binder: Binder<'_, '_, ::TransactionType<'txn>, A> = Binder::new( BinderContext::new( self.table_cache(), self.view_cache(), transaction, self.scala_functions(), self.table_functions(), - Arc::new(AtomicUsize::new(0)), ), ¶ms, None, ); - /// Build a logical plan. - /// - /// SELECT a,b FROM t1 ORDER BY a LIMIT 1; - /// Scan(t1) - /// Sort(a) - /// Limit(1) - /// Project(a,b) - let source_plan = binder.bind(stmt)?; - let mut best_plan = self - .optimizer_pipeline - .instantiate(source_plan) - .find_best(Some(&transaction.meta_loader(self.meta_cache())))?; + let source_plan = build(&mut binder, &mut plan_arena)?; + drop(binder); + let mut best_plan = self.optimizer_pipeline.instantiate(source_plan).find_best( + Some(&StatisticMetaLoader::new(self.meta_cache())), + &mut plan_arena, + )?; if let Operator::Analyze(op) = &mut best_plan.operator { if op.histogram_buckets.is_none() { @@ -520,191 +526,315 @@ impl State { } } - Ok(best_plan) + Ok((best_plan, plan_arena)) } - fn execute<'a, 'txn, A: AsRef<[(&'static str, DataValue)]>>( + pub(crate) fn execute<'a, 'txn, A, F>( &'a self, transaction: &'a mut S::TransactionType<'txn>, - stmt: &Statement, params: A, - ) -> Result<(SchemaRef, Executor<'a, S::TransactionType<'txn>>), DatabaseError> + build: F, + ) -> Result< + ( + Schema, + PlanArena<'a>, + Executor<'a, S::TransactionType<'txn>>, + ), + DatabaseError, + > where S: 'txn, + A: AsRef<[(&'static str, DataValue)]>, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, S::TransactionType<'txn>, A>, + &mut PlanArena<'a>, + ) -> Result, { transaction.begin_statement_scope()?; match (|| { - let mut plan = self.build_plan(stmt, params, transaction)?; - let schema = plan.output_schema().clone(); - let mut arena = ExecArena::default(); - let root = build_write( - &mut arena, - plan, - ( - &self.table_cache, - &self.view_cache, - &self.meta_cache, - &self.scala_functions, - &self.table_functions, - ), - transaction, + let (mut plan, mut plan_arena) = self.build_plan(params, transaction, build)?; + let schema = plan.take_schema(&mut plan_arena); + let mut arena = ExecArena::new(); + let read_context = ExecutionContext::new( + &self.table_cache, + &self.view_cache, + &self.meta_cache, + &self.scala_functions, + &self.table_functions, ); + let root = build_write(&mut arena, &mut plan_arena, plan, read_context, transaction); let executor = Executor::new(arena, root); - Ok((schema, executor)) + Ok((schema, plan_arena, executor)) })() { Ok(result) => Ok(result), - Err(err) => { - transaction.end_statement_scope()?; - Err(err) + Err(err) => Err(err), + } + } + + pub(crate) fn execute_mut<'a, 'txn, A, F>( + &'a mut self, + transaction: &'a mut S::TransactionType<'txn>, + params: A, + build: F, + ) -> Result< + ( + Schema, + PlanArena<'a>, + Executor<'a, S::TransactionType<'txn>>, + ), + DatabaseError, + > + where + S: 'txn, + A: AsRef<[(&'static str, DataValue)]>, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, S::TransactionType<'txn>, A>, + &mut PlanArena<'a>, + ) -> Result, + { + transaction.begin_statement_scope()?; + let State { + scala_functions, + table_functions, + meta_cache, + table_cache, + view_cache, + table_arena, + optimizer_pipeline, + histogram_buckets, + .. + } = self; + let mut plan_arena = PlanArena::new(table_arena); + let mut binder = Binder::new( + BinderContext::new( + table_cache, + view_cache, + transaction, + scala_functions, + table_functions, + ), + ¶ms, + None, + ); + let source_plan = build(&mut binder, &mut plan_arena)?; + drop(binder); + let mut plan = optimizer_pipeline + .instantiate(source_plan) + .find_best(Some(&StatisticMetaLoader::new(meta_cache)), &mut plan_arena)?; + + if let Operator::Analyze(op) = &mut plan.operator { + if op.histogram_buckets.is_none() { + op.histogram_buckets = *histogram_buckets; } } + + let schema = plan.take_schema(&mut plan_arena); + let mut arena = ExecArena::new(); + let cache = ExecutionContext::new( + table_cache, + view_cache, + meta_cache, + scala_functions, + table_functions, + ); + let root = build_write(&mut arena, &mut plan_arena, plan, cache, transaction); + let executor = Executor::new(arena, root); + + Ok((schema, plan_arena, executor)) } } /// Main database handle for executing SQL and creating transactions. pub struct Database { pub(crate) storage: S, - transaction_isolation: TransactionIsolationLevel, - mdl: Arc>, - pub(crate) state: Arc>, + pub(crate) transaction_isolation: TransactionIsolationLevel, + pub(crate) state: State, } -impl Database { - /// Runs one or more SQL statements and returns an iterator for the final result set. - /// - /// Earlier statements in the same SQL string are executed eagerly. The last - /// statement is exposed as a streaming iterator. - /// - /// # Examples - /// - /// ```rust - /// use kite_sql::db::{DataBaseBuilder, ResultIter}; - /// - /// let database = DataBaseBuilder::path(".").build_in_memory().unwrap(); - /// database.run("create table t (id int primary key)").unwrap().done().unwrap(); - /// let mut iter = database.run("select * from t").unwrap(); - /// let _schema = iter.schema().clone(); - /// iter.done().unwrap(); - /// ``` - pub fn run>(&self, sql: T) -> Result, DatabaseError> { - let sql = sql.as_ref(); - let statements = prepare_all(sql).map_err(|err| err.with_sql_context(sql))?; - let has_ddl = statements - .iter() - .try_fold(false, |has_ddl, stmt| { - Ok::<_, DatabaseError>(has_ddl || matches!(command_type(stmt)?, CommandType::DDL)) - }) - .map_err(|err| err.with_sql_context(sql))?; - - if statements.len() > 1 && has_ddl { - return Err(DatabaseError::UnsupportedStmt( - "DDL is not allowed in multi-statement execution".to_string(), - ) - .with_sql_context(sql)); - } - - let guard = if has_ddl { - MetaDataLock::Write(self.mdl.write_arc()) - } else { - MetaDataLock::Read(self.mdl.read_arc()) - }; - - let transaction = Box::into_raw(Box::new( - self.storage - .transaction_with_isolation(self.transaction_isolation)?, - )); - let mut statements = statements.into_iter().peekable(); - - while let Some(statement) = statements.next() { - let (schema, executor) = - match self - .state - .execute(unsafe { &mut *transaction }, &statement, &[]) - { - Ok(result) => result, - Err(err) => { - unsafe { drop(Box::from_raw(transaction)) }; - return Err(err.with_sql_context(sql)); - } - }; - - if statements.peek().is_some() { - if let Err(err) = TransactionIter::new(schema, executor, transaction).done() { - unsafe { drop(Box::from_raw(transaction)) }; - return Err(err.with_sql_context(sql)); +impl DDLApply { + fn apply_to( + self, + state: &mut State, + plan_arena: &PlanArena, + ) -> Result { + let mut catalog_changed = false; + match self { + DDLApply::UpsertTable { + table, + clear_statistics, + } => { + let name = table.name().clone(); + let table = table.transplant_to_table_arena(plan_arena)?; + state.table_cache.insert(name.clone(), table); + if clear_statistics { + state + .meta_cache + .retain(|(cached_table_name, _), _| cached_table_name != &name); } - } else { - let inner = Box::into_raw(Box::new(TransactionIter::new( - schema, - executor, - transaction, - ))); - return Ok(DatabaseIter { - transaction, - inner, - _guard: Some(guard), - }); + catalog_changed = true; + } + DDLApply::DropTable { name } => { + state.table_cache.remove(&name); + state + .meta_cache + .retain(|(cached_table_name, _), _| cached_table_name != &name); + catalog_changed = true; + } + DDLApply::UpsertView { view } => { + let name = view.name.clone(); + plan_arena.materialize_into_table_arena(); + state.view_cache.insert(name, view); + catalog_changed = true; + } + DDLApply::DropView { name } => { + state.view_cache.remove(&name); + catalog_changed = true; + } + DDLApply::UpsertStatisticsMeta { + table_name, + index_id, + meta, + } => { + state.meta_cache.insert((table_name, index_id), meta); + } + DDLApply::RemoveStatisticsMeta { + table_name, + index_id, + } => { + state.meta_cache.remove(&(table_name, index_id)); } } - - unsafe { drop(Box::from_raw(transaction)) }; - Err(DatabaseError::EmptyStatement.with_sql_context(sql)) + Ok(catalog_changed) } +} - /// Executes a prepared [`Statement`] inside the current transaction. - pub fn execute>( - &self, - statement: &Statement, +impl Database { + pub(crate) fn execute_mut( + &mut self, + context: &str, params: A, - ) -> Result, DatabaseError> { - let guard = if matches!(command_type(statement)?, CommandType::DDL) { - MetaDataLock::Write(self.mdl.write_arc()) - } else { - MetaDataLock::Read(self.mdl.read_arc()) - }; + build: F, + ) -> Result<(), DatabaseError> + where + A: AsRef<[(&'static str, DataValue)]>, + F: for<'a, 'txn, 'bind> FnOnce( + &mut Binder<'bind, '_, S::TransactionType<'txn>, A>, + &mut PlanArena<'a>, + ) -> Result, + { let transaction = Box::into_raw(Box::new( self.storage .transaction_with_isolation(self.transaction_isolation)?, )); - let (schema, executor) = - match self - .state - .execute(unsafe { &mut *transaction }, statement, params) - { + let state = std::ptr::from_mut(&mut self.state); + let (schema, plan_arena, executor) = + match unsafe { (&mut *state).execute_mut(&mut *transaction, params, build) } { Ok(result) => result, Err(err) => { unsafe { drop(Box::from_raw(transaction)) }; - return Err(err); + return Err(err.with_sql_context(context)); } }; - let inner = Box::into_raw(Box::new(TransactionIter::new( - schema, - executor, - transaction, - ))); - Ok(DatabaseIter { - transaction, - inner, - _guard: Some(guard), + let (plan_arena, apply) = + match TransactionIter::new(schema, plan_arena, executor, transaction) + .done_with_ddl_apply() + { + Ok(apply) => apply, + Err(err) => { + unsafe { drop(Box::from_raw(transaction)) }; + return Err(err.with_sql_context(context)); + } + }; + + if let Err(err) = unsafe { Box::from_raw(transaction).commit() } { + return Err(err.with_sql_context(context)); + } + + let mut catalog_changed = false; + for apply in apply { + catalog_changed |= unsafe { apply.apply_to(&mut *state, &plan_arena) } + .map_err(|err| err.with_sql_context(context))?; + } + if catalog_changed { + unsafe { (&mut *state).recycle_table_arena() }; + } + Ok(()) + } + + pub fn analyze(&mut self, table_name: impl AsRef) -> Result<(), DatabaseError> { + let context = "ANALYZE"; + let table_name: TableName = table_name.as_ref().into(); + self.execute_mut(context, &[], move |binder, arena| { + binder.bind_analyze(table_name, arena) }) } + pub fn load(&mut self, kind: CatalogKind) -> Result<(), DatabaseError> { + match kind { + CatalogKind::ScalarFunction(function) => { + self.state.load_scalar_function(function); + Ok(()) + } + CatalogKind::TableFunction(function) => { + self.state.load_table_function(function)?; + self.state.recycle_table_arena(); + Ok(()) + } + CatalogKind::Table(name) => { + let transaction = self.storage.transaction()?; + let mut table_codec = TableCodec::default(); + let table = transaction + .load_table( + &mut table_codec, + self.state.table_arena.borrow_mut(), + name.clone(), + )? + .ok_or(DatabaseError::TableNotFound)?; + for index in table.indexes() { + let index = self.state.table_arena.borrow().index(*index); + if let Some(meta) = + transaction.statistics_meta(&mut table_codec, name.as_ref(), index.id)? + { + self.state.meta_cache.insert((name.clone(), index.id), meta); + } + } + self.state.table_cache.insert(name, table); + self.state.recycle_table_arena(); + Ok(()) + } + CatalogKind::View(name) => { + let transaction = self.storage.transaction()?; + let mut table_codec = TableCodec::default(); + let view = transaction + .load_view( + &mut table_codec, + &self.state.table_cache, + &self.state.table_arena, + &self.state.scala_functions, + &self.state.table_functions, + name.clone(), + )? + .ok_or(DatabaseError::ViewNotFound)?; + self.state.view_cache.insert(name, view); + self.state.recycle_table_arena(); + Ok(()) + } + } + } + /// Opens a new explicit transaction. /// /// Statements executed through the returned transaction share the same /// transactional context until [`DBTransaction::commit`] is called. pub fn new_transaction(&self) -> Result, DatabaseError> { - let guard = self.mdl.read_arc(); let transaction = self .storage .transaction_with_isolation(self.transaction_isolation)?; - let state = self.state.clone(); Ok(DBTransaction { inner: transaction, - _guard: guard, - state, + state: &self.state, }) } @@ -725,6 +855,60 @@ impl Database { } } +impl<'a, S: Storage> BindSource for &'a Database { + type Iter = DatabaseIter<'a, S>; + type Transaction = S::TransactionType<'a>; + + fn execute(self, params: A, build: F) -> Result + where + A: AsRef<[(&'static str, DataValue)]>, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, Self::Transaction, A>, + &mut PlanArena<'_>, + ) -> Result, + { + let transaction = Box::into_raw(Box::new( + self.storage + .transaction_with_isolation(self.transaction_isolation)?, + )); + let (schema, plan_arena, executor) = + match self + .state + .execute(unsafe { &mut *transaction }, params, build) + { + Ok(result) => result, + Err(err) => { + unsafe { drop(Box::from_raw(transaction)) }; + return Err(err); + } + }; + let inner = Box::into_raw(Box::new(TransactionIter::new( + schema, + plan_arena, + executor, + transaction, + ))); + Ok(DatabaseIter { transaction, inner }) + } + + #[cfg(feature = "orm")] + fn explain(self, params: A, build: F) -> Result + where + A: AsRef<[(&'static str, DataValue)]>, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, Self::Transaction, A>, + &mut PlanArena<'_>, + ) -> Result, + { + let mut transaction = self + .storage + .transaction_with_isolation(self.transaction_isolation)?; + transaction.begin_statement_scope()?; + let (plan, mut arena) = self.state.build_plan(params, &transaction, build)?; + Ok(plan.explain(&mut arena, 0)) + } +} + impl Database where S: CheckpointableStorage, @@ -740,27 +924,12 @@ where /// Borrowing interface for result iterators returned by database execution APIs. pub trait BorrowResultIter { - /// Returns the output schema for the current result set. - fn schema(&self) -> &SchemaRef; + /// Borrows the output schema for the current result set. + fn schema(&self, f: impl FnOnce(&SchemaView<'_, '_>) -> R) -> R; /// Returns the next row as a borrowed tuple. fn next_borrowed_tuple(&mut self) -> Result, DatabaseError>; - /// Creates a mapped iterator that transforms borrowed tuples into owned output values. - fn map_result(self, mapper: F) -> MappedResultIter - where - Self: Sized, - F: for<'a> FnMut(&'a SchemaRef, &'a Tuple) -> Result, - { - let schema = self.schema().clone(); - MappedResultIter { - inner: self, - mapper, - schema, - _marker: PhantomData, - } - } - /// Finishes consuming the iterator and flushes any remaining work. fn done(self) -> Result<(), DatabaseError>; } @@ -774,12 +943,12 @@ pub trait ResultIter: BorrowResultIter + Iterator`, which is typically generated by + /// implements `From<(&SchemaView, Tuple)>`, which is typically generated by /// `#[derive(Model)]`. fn orm(self) -> OrmIter where Self: Sized, - T: for<'a> From<(&'a SchemaRef, Tuple)>, + T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, { OrmIter::new(self) } @@ -787,49 +956,10 @@ pub trait ResultIter: BorrowResultIter + Iterator ResultIter for I where I: BorrowResultIter + Iterator> {} -/// Typed adapter over a borrowing result iterator. -pub struct MappedResultIter { - inner: I, - mapper: F, - schema: SchemaRef, - _marker: PhantomData, -} - -impl MappedResultIter -where - I: BorrowResultIter, - F: for<'a> FnMut(&'a SchemaRef, &'a Tuple) -> Result, -{ - pub fn schema(&self) -> &SchemaRef { - &self.schema - } - - pub fn done(self) -> Result<(), DatabaseError> { - self.inner.done() - } -} - -impl Iterator for MappedResultIter -where - I: BorrowResultIter, - F: for<'a> FnMut(&'a SchemaRef, &'a Tuple) -> Result, -{ - type Item = Result; - - fn next(&mut self) -> Option { - match self.inner.next_borrowed_tuple() { - Ok(Some(tuple)) => Some((self.mapper)(&self.schema, tuple)), - Ok(None) => None, - Err(err) => Some(Err(err)), - } - } -} - #[cfg(feature = "orm")] /// Typed adapter over a [`ResultIter`] that yields ORM models instead of raw tuples. pub struct OrmIter { inner: I, - schema: SchemaRef, _marker: PhantomData, } @@ -837,21 +967,18 @@ pub struct OrmIter { impl OrmIter where I: ResultIter, - T: for<'a> From<(&'a SchemaRef, Tuple)>, + T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, { fn new(inner: I) -> Self { - let schema = inner.schema().clone(); - Self { inner, - schema, _marker: PhantomData, } } - /// Returns the schema of the underlying result set. - pub fn schema(&self) -> &SchemaRef { - &self.schema + /// Borrows the schema of the underlying result set. + pub fn schema(&self, f: impl FnOnce(&SchemaView<'_, '_>) -> R) -> R { + self.inner.schema(f) } /// Finishes the underlying raw iterator. @@ -864,22 +991,23 @@ where impl Iterator for OrmIter where I: ResultIter, - T: for<'a> From<(&'a SchemaRef, Tuple)>, + T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, { type Item = Result; fn next(&mut self) -> Option { - self.inner - .next() - .map(|result| result.map(|tuple| T::from((&self.schema, tuple)))) + let tuple = match self.inner.next()? { + Ok(tuple) => tuple, + Err(err) => return Some(Err(err)), + }; + Some(Ok(self.inner.schema(|schema| T::from((schema, tuple))))) } } -/// Raw result iterator returned by [`Database::run`] and [`Database::execute`]. +/// Raw result iterator returned by database execution APIs. pub struct DatabaseIter<'a, S: Storage + 'a> { - transaction: *mut S::TransactionType<'a>, - inner: *mut TransactionIter<'a, S::TransactionType<'a>>, - _guard: Option, + pub(crate) transaction: *mut S::TransactionType<'a>, + pub(crate) inner: *mut TransactionIter<'a, S::TransactionType<'a>>, } impl Drop for DatabaseIter<'_, S> { @@ -895,17 +1023,13 @@ impl Drop for DatabaseIter<'_, S> { impl DatabaseIter<'_, S> { #[inline] - pub fn schema(&self) -> &SchemaRef { - unsafe { (*self.inner).schema() } + pub fn schema(&self, f: impl FnOnce(&SchemaView<'_, '_>) -> R) -> R { + unsafe { (*self.inner).schema(f) } } #[inline] pub fn next_borrowed_tuple(&mut self) -> Result, DatabaseError> { - let result = unsafe { (*self.inner).next_borrowed_tuple() }; - if result.as_ref().is_ok_and(Option::is_none) { - self._guard = None; - } - result + unsafe { (*self.inner).next_borrowed_tuple() } } #[inline] @@ -924,17 +1048,13 @@ impl Iterator for DatabaseIter<'_, S> { type Item = Result; fn next(&mut self) -> Option { - let result = unsafe { (*self.inner).next() }; - if result.is_none() { - self._guard = None; - } - result + unsafe { (*self.inner).next() } } } impl BorrowResultIter for DatabaseIter<'_, S> { - fn schema(&self) -> &SchemaRef { - DatabaseIter::schema(self) + fn schema(&self, f: impl FnOnce(&SchemaView<'_, '_>) -> R) -> R { + DatabaseIter::schema(self, f) } fn next_borrowed_tuple(&mut self) -> Result, DatabaseError> { @@ -948,75 +1068,82 @@ impl BorrowResultIter for DatabaseIter<'_, S> { /// Explicit transaction handle created by [`Database::new_transaction`]. pub struct DBTransaction<'a, S: Storage + 'a> { - inner: S::TransactionType<'a>, - _guard: ArcRwLockReadGuard, - state: Arc>, + pub(crate) inner: S::TransactionType<'a>, + pub(crate) state: &'a State, } impl<'txn, S: Storage> DBTransaction<'txn, S> { - /// Runs SQL inside the current transaction and returns the final result iterator. - pub fn run<'a, T: AsRef>( - &'a mut self, - sql: T, - ) -> Result>, DatabaseError> { - let sql = sql.as_ref(); - let mut statements = prepare_all(sql).map_err(|err| err.with_sql_context(sql))?; - let last_statement = statements - .pop() - .ok_or_else(|| DatabaseError::EmptyStatement.with_sql_context(sql))?; - - for statement in statements { - self.execute(&statement, &[]) - .map_err(|err| err.with_sql_context(sql))? - .done() - .map_err(|err| err.with_sql_context(sql))?; - } + /// Commits the current transaction. + pub fn commit(self) -> Result<(), DatabaseError> { + self.inner.commit()?; - self.execute(&last_statement, &[]) - .map_err(|err| err.with_sql_context(sql)) + Ok(()) } +} - /// Executes a prepared [`Statement`] inside the current transaction. - pub fn execute<'a, A: AsRef<[(&'static str, DataValue)]>>( - &'a mut self, - statement: &Statement, - params: A, - ) -> Result>, DatabaseError> { - if matches!(command_type(statement)?, CommandType::DDL) { - return Err(DatabaseError::UnsupportedStmt( - "`DDL` is not allowed to execute within a transaction".to_string(), - )); - } +impl<'a, 'txn, S: Storage> BindSource for &'a mut DBTransaction<'txn, S> { + type Iter = TransactionIter<'a, S::TransactionType<'txn>>; + type Transaction = S::TransactionType<'txn>; + + fn execute(self, params: A, build: F) -> Result + where + A: AsRef<[(&'static str, DataValue)]>, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, Self::Transaction, A>, + &mut PlanArena<'_>, + ) -> Result, + { let transaction = std::ptr::from_mut(&mut self.inner); - let (schema, executor) = + let (schema, plan_arena, executor) = self.state - .execute(unsafe { &mut *transaction }, statement, params)?; - Ok(TransactionIter::new(schema, executor, transaction)) + .execute(unsafe { &mut *transaction }, params, build)?; + Ok(TransactionIter::new( + schema, + plan_arena, + executor, + transaction, + )) } - /// Commits the current transaction. - pub fn commit(self) -> Result<(), DatabaseError> { - self.inner.commit()?; - - Ok(()) + #[cfg(feature = "orm")] + fn explain(self, params: A, build: F) -> Result + where + A: AsRef<[(&'static str, DataValue)]>, + F: for<'bind> FnOnce( + &mut Binder<'bind, '_, Self::Transaction, A>, + &mut PlanArena<'_>, + ) -> Result, + { + self.inner.begin_statement_scope()?; + let (plan, mut arena) = self.state.build_plan(params, &self.inner, build)?; + Ok(plan.explain(&mut arena, 0)) } } -/// Raw result iterator returned by [`DBTransaction::run`] and [`DBTransaction::execute`]. +/// Raw result iterator returned by transaction execution APIs. pub struct TransactionIter<'a, T: Transaction + 'a> { executor: Option>, - schema: SchemaRef, + plan_arena: Option>, + schema: Schema, transaction: *mut T, statement_scope_active: bool, + ddl_apply: Vec, } impl<'a, T: Transaction + 'a> TransactionIter<'a, T> { - fn new(schema: SchemaRef, executor: Executor<'a, T>, transaction: *mut T) -> Self { + pub(crate) fn new( + schema: Schema, + plan_arena: PlanArena<'a>, + executor: Executor<'a, T>, + transaction: *mut T, + ) -> Self { Self { executor: Some(executor), + plan_arena: Some(plan_arena), schema, transaction, statement_scope_active: true, + ddl_apply: Vec::new(), } } @@ -1026,14 +1153,21 @@ impl<'a, T: Transaction + 'a> TransactionIter<'a, T> { return Ok(()); } - self.executor.take(); + if let Some(mut executor) = self.executor.take() { + self.ddl_apply.extend(executor.take_ddl_apply()); + } self.statement_scope_active = false; unsafe { (*self.transaction).end_statement_scope() } } #[inline] - pub fn schema(&self) -> &SchemaRef { - &self.schema + pub fn schema(&self, f: impl FnOnce(&SchemaView<'_, '_>) -> R) -> R { + let plan_arena = self + .plan_arena + .as_ref() + .expect("result iterator schema is unavailable after statement completion"); + let schema = SchemaView::new(&self.schema, plan_arena); + f(&schema) } #[inline] @@ -1042,7 +1176,11 @@ impl<'a, T: Transaction + 'a> TransactionIter<'a, T> { return Ok(None); }; let executor_ptr = std::ptr::from_mut(executor); - match unsafe { (*executor_ptr).next_tuple() } { + let plan_arena = self + .plan_arena + .as_mut() + .expect("result iterator plan arena is unavailable after statement completion"); + match unsafe { (*executor_ptr).next_tuple(plan_arena) } { Ok(Some(tuple)) => Ok(Some(tuple)), Ok(None) => { self.finish_statement_scope()?; @@ -1060,6 +1198,16 @@ impl<'a, T: Transaction + 'a> TransactionIter<'a, T> { while self.next_borrowed_tuple()?.is_some() {} Ok(()) } + + fn done_with_ddl_apply(mut self) -> Result<(PlanArena<'a>, Vec), DatabaseError> { + while self.next_borrowed_tuple()?.is_some() {} + Ok(( + self.plan_arena + .take() + .expect("DDL apply plan arena is unavailable after statement completion"), + std::mem::take(&mut self.ddl_apply), + )) + } } impl Drop for TransactionIter<'_, T> { @@ -1074,7 +1222,11 @@ impl Iterator for TransactionIter<'_, T> { fn next(&mut self) -> Option { let result = { let executor = self.executor.as_mut()?; - executor.next_tuple() + let plan_arena = self + .plan_arena + .as_mut() + .expect("result iterator plan arena is unavailable after statement completion"); + executor.next_tuple(plan_arena) }; match result { Ok(Some(tuple)) => Some(Ok(tuple.clone())), @@ -1091,8 +1243,8 @@ impl Iterator for TransactionIter<'_, T> { } impl BorrowResultIter for TransactionIter<'_, T> { - fn schema(&self) -> &SchemaRef { - TransactionIter::schema(self) + fn schema(&self, f: impl FnOnce(&SchemaView<'_, '_>) -> R) -> R { + TransactionIter::schema(self, f) } fn next_borrowed_tuple(&mut self) -> Result, DatabaseError> { @@ -1107,20 +1259,27 @@ impl BorrowResultIter for TransactionIter<'_, T> { #[cfg(all(test, not(target_arch = "wasm32")))] pub(crate) mod test { use crate::binder::{Binder, BinderContext}; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; + #[cfg(feature = "unsafe_txdb_checkpoint")] + use crate::db::CatalogKind; use crate::db::{BorrowResultIter, DataBaseBuilder, DatabaseError}; use crate::expression::ScalarExpression; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; - use crate::storage::{Storage, TableCache, Transaction, TransactionIsolationLevel}; + use crate::planner::PlanArena; + use crate::storage::{ + table_codec::TableCodec, Storage, TableCache, Transaction, TransactionIsolationLevel, + }; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; use chrono::{Datelike, Local}; use std::io::ErrorKind; + #[cfg(feature = "unsafe_txdb_checkpoint")] use std::sync::atomic::AtomicUsize; #[cfg(feature = "unsafe_txdb_checkpoint")] use std::sync::atomic::Ordering; + #[cfg(feature = "unsafe_txdb_checkpoint")] use std::sync::Arc; #[cfg(feature = "unsafe_txdb_checkpoint")] use std::thread; @@ -1139,8 +1298,9 @@ pub(crate) mod test { } pub(crate) fn build_table( - table_cache: &TableCache, + table_cache: &mut TableCache, transaction: &mut T, + plan_arena: &mut PlanArena, ) -> Result<(), DatabaseError> { let columns = vec![ ColumnCatalog::new( @@ -1159,7 +1319,17 @@ pub(crate) mod test { ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), ), ]; - let _ = transaction.create_table(table_cache, "t1".to_string().into(), columns, false)?; + let mut table_codec = TableCodec::default(); + if let Some(table) = transaction.create_table( + &mut table_codec, + plan_arena, + "t1".to_string().into(), + columns, + false, + )? { + let table = table.transplant_to_table_arena(plan_arena)?; + table_cache.insert(table.name().clone(), table); + } Ok(()) } @@ -1199,11 +1369,8 @@ pub(crate) mod test { #[test] fn test_run_sql() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let database = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - let mut transaction = database.storage.transaction()?; - - build_table(database.state.table_cache(), &mut transaction)?; - transaction.commit()?; + let mut database = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + database.ddl("create table t1(c1 int primary key, c2 boolean, c3 int)")?; for result in database.run("select * from t1")? { println!("{:#?}", result?); @@ -1218,14 +1385,12 @@ pub(crate) mod test { let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; let mut iter = kite_sql.run("select current_date()")?; - assert_eq!( - iter.schema(), - &Arc::new(vec![ColumnRef::from(ColumnCatalog::new( - "current_date()".to_string(), - true, - ColumnDesc::new(LogicalType::Date, None, false, None).unwrap() - ))]) - ); + iter.schema(|schema| { + assert_eq!(schema.len(), 1); + let column = schema.get(0).unwrap(); + assert_eq!(column.name(), "current_date()"); + assert_eq!(column.datatype(), &LogicalType::Date); + }); assert_eq!( iter.next().unwrap()?, Tuple::new( @@ -1247,15 +1412,12 @@ pub(crate) mod test { "SELECT * FROM (select * from table(numbers(10)) a ORDER BY number LIMIT 5) OFFSET 3", )?; - let mut column = ColumnCatalog::new( - "number".to_string(), - true, - ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), - ); - let number_column_id = iter.schema()[0].id().unwrap(); - column.set_ref_table("a".to_string().into(), number_column_id, false); - - assert_eq!(iter.schema(), &Arc::new(vec![ColumnRef::from(column)])); + iter.schema(|schema| { + assert_eq!(schema.len(), 1); + let column = schema.get(0).unwrap(); + assert_eq!(column.name(), "number"); + assert_eq!(column.datatype(), &LogicalType::Integer); + }); assert_eq!( iter.next().unwrap()?, Tuple::new(None, vec![DataValue::Int32(3)]) @@ -1270,14 +1432,10 @@ pub(crate) mod test { #[test] fn test_join_on_alias_right_key_is_localized() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")? - .done()?; - kite_sql - .run("CREATE TABLE empty (e_id INT PRIMARY KEY, x INT)")? - .done()?; + kite_sql.ddl("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")?; + kite_sql.ddl("CREATE TABLE empty (e_id INT PRIMARY KEY, x INT)")?; let stmt = crate::db::prepare( "SELECT * FROM onecolumn AS a(aid, x) JOIN empty AS b(bid, y) ON a.x = b.y", @@ -1290,13 +1448,16 @@ pub(crate) mod test { &transaction, kite_sql.state.scala_functions(), kite_sql.state.table_functions(), - Arc::new(AtomicUsize::new(0)), ), &[], None, ); - let source_plan = binder.bind(&stmt)?; - let best_plan = kite_sql.state.build_plan(&stmt, [], &transaction)?; + let mut source_plan_arena = PlanArena::new(kite_sql.state.table_arena()); + let source_plan = binder.bind(&stmt, &mut source_plan_arena)?; + let (best_plan, _best_plan_arena) = + kite_sql + .state + .build_plan([], &transaction, |binder, arena| binder.bind(&stmt, arena))?; let join_plan = match source_plan.operator { Operator::Project(_) => source_plan.childrens.pop_only(), @@ -1354,14 +1515,10 @@ pub(crate) mod test { #[test] fn test_join_on_with_right_filter_keeps_localized_key() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")? - .done()?; - kite_sql - .run("CREATE TABLE twocolumn (t_id INT PRIMARY KEY, x INT NULL, y INT NULL)")? - .done()?; + kite_sql.ddl("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")?; + kite_sql.ddl("CREATE TABLE twocolumn (t_id INT PRIMARY KEY, x INT NULL, y INT NULL)")?; let stmt = crate::db::prepare( "SELECT o.x, t.y FROM onecolumn o INNER JOIN twocolumn t ON (o.x=t.x AND t.y=53)", @@ -1374,13 +1531,16 @@ pub(crate) mod test { &transaction, kite_sql.state.scala_functions(), kite_sql.state.table_functions(), - Arc::new(AtomicUsize::new(0)), ), &[], None, ); - let source_plan = binder.bind(&stmt)?; - let best_plan = kite_sql.state.build_plan(&stmt, [], &transaction)?; + let mut source_plan_arena = PlanArena::new(kite_sql.state.table_arena()); + let source_plan = binder.bind(&stmt, &mut source_plan_arena)?; + let (best_plan, _best_plan_arena) = + kite_sql + .state + .build_plan([], &transaction, |binder, arena| binder.bind(&stmt, arena))?; let join_plan = match source_plan.operator { Operator::Project(_) => source_plan.childrens.pop_only(), @@ -1414,12 +1574,12 @@ pub(crate) mod test { unreachable!("expected join filter"); }; let mut referenced_columns = Vec::new(); - filter.visit_referenced_columns(true, &mut |column| { - referenced_columns.push(column.clone()); + filter.visit_referenced_columns(&mut source_plan_arena, &mut |_, column| { + referenced_columns.push(*column); true }); assert_eq!(referenced_columns.len(), 1); - assert_eq!(referenced_columns[0].name(), "y"); + assert_eq!(source_plan_arena.column(referenced_columns[0]).name(), "y"); let join_plan = match best_plan.operator { Operator::Project(_) => best_plan.childrens.pop_only(), @@ -1457,14 +1617,10 @@ pub(crate) mod test { #[test] fn test_join_on_with_right_filter_keeps_localized_key_with_data() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")? - .done()?; - kite_sql - .run("CREATE TABLE twocolumn (t_id INT PRIMARY KEY, x INT NULL, y INT NULL)")? - .done()?; + kite_sql.ddl("CREATE TABLE onecolumn (id INT PRIMARY KEY, x INT NULL)")?; + kite_sql.ddl("CREATE TABLE twocolumn (t_id INT PRIMARY KEY, x INT NULL, y INT NULL)")?; kite_sql .run("INSERT INTO onecolumn(id, x) VALUES (0, 44), (1, NULL), (2, 42)")? .done()?; @@ -1478,7 +1634,10 @@ pub(crate) mod test { "SELECT o.x, t.y FROM onecolumn o INNER JOIN twocolumn t ON (o.x=t.x AND t.y=53)", )?; let transaction = kite_sql.storage.transaction()?; - let best_plan = kite_sql.state.build_plan(&stmt, [], &transaction)?; + let (best_plan, _best_plan_arena) = + kite_sql + .state + .build_plan([], &transaction, |binder, arena| binder.bind(&stmt, arena))?; let join_plan = match best_plan.operator { Operator::Project(_) => best_plan.childrens.pop_only(), Operator::Join(_) => best_plan, @@ -1539,11 +1698,9 @@ pub(crate) mod test { #[test] fn test_prepare_statment() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t1 (a int primary key, b int)")?; kite_sql.run("insert into t1 values(0, 0)")?.done()?; kite_sql.run("insert into t1 values(1, 1)")?.done()?; kite_sql.run("insert into t1 values(2, 2)")?.done()?; @@ -1552,14 +1709,14 @@ pub(crate) mod test { { let statement = crate::db::prepare("explain select * from t1 where b > $1")?; - let mut iter = kite_sql.execute(&statement, &[("$1", DataValue::Int32(0))])?; + let mut iter = kite_sql.execute(statement, &[("$1", DataValue::Int32(0))])?; - assert_eq!( - iter.next().unwrap()?.values[0].utf8().unwrap(), - "Projection [t1.a, t1.b] [Project => (Sort Option: Follow)] - Filter (t1.b > 0), Is Having: false [Filter => (Sort Option: Follow)] - TableScan t1 -> [a, b] [SeqScan => (Sort Option: None)]" - ) + let row = iter.next().unwrap()?; + let plan = row.values[0].utf8().unwrap(); + assert!(plan.contains("Projection")); + assert!(plan.contains("Filter (")); + assert!(plan.contains(" > 0")); + assert!(plan.contains("TableScan t1 -> [#")); } // Aggregate { @@ -1568,7 +1725,7 @@ pub(crate) mod test { )?; let mut iter = kite_sql.execute( - &statement, + statement, &[ ("$1", DataValue::Int32(0)), ("$2", DataValue::Int32(0)), @@ -1576,19 +1733,19 @@ pub(crate) mod test { ("$4", DataValue::Int32(0)), ], )?; - assert_eq!( - iter.next().unwrap()?.values[0].utf8().unwrap(), - "Projection [(t1.a + 0), Max((t1.b + 0))] [Project => (Sort Option: Follow)] - Aggregate [Max((t1.b + 0))] -> Group By [(t1.a + 0)] [HashAggregate => (Sort Option: None)] - Filter (t1.b > 1), Is Having: false [Filter => (Sort Option: Follow)] - TableScan t1 -> [a, b] [SeqScan => (Sort Option: None)]" - ) + let row = iter.next().unwrap()?; + let plan = row.values[0].utf8().unwrap(); + assert!(plan.contains("Projection")); + assert!(plan.contains("Aggregate")); + assert!(plan.contains("Filter (")); + assert!(plan.contains(" > 1")); + assert!(plan.contains("TableScan t1 -> [#")); } { let statement = crate::db::prepare("explain select *, $1 from (select * from t1 where b > $2) left join (select * from t1 where a > $3) on a > $4")?; let mut iter = kite_sql.execute( - &statement, + statement, &[ ("$1", DataValue::Int32(9)), ("$2", DataValue::Int32(0)), @@ -1596,17 +1753,14 @@ pub(crate) mod test { ("$4", DataValue::Int32(0)), ], )?; - assert_eq!( - iter.next().unwrap()?.values[0].utf8().unwrap(), - "Projection [t1.a, t1.b, 9] [Project => (Sort Option: Follow)] - LeftOuter Join Where (t1.a > 0) [NestLoopJoin => (Sort Option: None)] - Projection [t1.a, t1.b] [Project => (Sort Option: Follow)] - Filter (t1.b > 0), Is Having: false [Filter => (Sort Option: Follow)] - TableScan t1 -> [a, b] [SeqScan => (Sort Option: None)] - Projection [t1.a, t1.b] [Project => (Sort Option: Follow)] - Filter (t1.a > 1), Is Having: false [Filter => (Sort Option: Follow)] - TableScan t1 -> [a, b] [SeqScan => (Sort Option: None)]" - ) + let row = iter.next().unwrap()?; + let plan = row.values[0].utf8().unwrap(); + assert!(plan.contains("Projection")); + assert!(plan.contains("LeftOuter Join")); + assert!(plan.contains("9")); + assert!(plan.contains("0")); + assert!(plan.contains("1")); + assert!(plan.contains("TableScan t1 -> [#")); } Ok(()) @@ -1618,32 +1772,38 @@ pub(crate) mod test { #[test] fn test_subquery_explain_uses_parameterized_index_for_in() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + + kite_sql.ddl("create table in_outer(id int primary key, a int)")?; + kite_sql.ddl("create table in_inner(id int primary key, v int)")?; + kite_sql.ddl("create table in_inner_nn(id int primary key, v int)")?; + kite_sql.ddl("create index in_inner_v_index on in_inner(v)")?; + kite_sql.ddl("create index in_inner_nn_v_index on in_inner_nn(v)")?; kite_sql - .run("create table in_outer(id int primary key, a int)")? - .done()?; - kite_sql - .run("create table in_inner(id int primary key, v int)")? - .done()?; - kite_sql - .run("create table in_inner_nn(id int primary key, v int)")? + .run("insert into in_outer values (0, null), (1, 1), (2, 2), (3, 3)")? .done()?; kite_sql - .run("create index in_inner_v_index on in_inner(v)")? + .run("insert into in_inner values (0, 2), (1, null)")? .done()?; kite_sql - .run("create index in_inner_nn_v_index on in_inner_nn(v)")? + .run("insert into in_inner_nn values (0, 2)")? .done()?; + kite_sql.ddl("create table in_outer_flag(id int primary key, a int, b int)")?; + kite_sql.ddl("create table in_inner_flag(id int primary key, v int, flag int)")?; + kite_sql.ddl("create table in_inner_flag_nn(id int primary key, v int, flag int)")?; + kite_sql.ddl("create index in_inner_flag_v_index on in_inner_flag(v)")?; + kite_sql.ddl("create index in_inner_flag_nn_v_index on in_inner_flag_nn(v)")?; + kite_sql - .run("insert into in_outer values (0, null), (1, 1), (2, 2), (3, 3)")? + .run("insert into in_outer_flag values (0, null, 1), (1, 1, 1), (2, 2, 1), (3, 3, 1)")? .done()?; kite_sql - .run("insert into in_inner values (0, 2), (1, null)")? + .run("insert into in_inner_flag values (0, 2, 1), (1, null, 1)")? .done()?; kite_sql - .run("insert into in_inner_nn values (0, 2)")? + .run("insert into in_inner_flag_nn values (0, 2, 1)")? .done()?; let collect_plan = |sql: &str| -> Result { @@ -1662,7 +1822,7 @@ pub(crate) mod test { let collect_ids = |sql: &str| -> Result, DatabaseError> { let mut iter = kite_sql.run(sql)?; let mut ids = Vec::new(); - while let Some(row) = iter.next() { + for row in iter.by_ref() { let row = row?; ids.push(row.values[0].i32().unwrap()); } @@ -1670,35 +1830,30 @@ pub(crate) mod test { Ok(ids) }; - let assert_mark_in_uses_parameterized_index = - |sql: &str, index_name: &str| -> Result<(), DatabaseError> { - let explain_plan = collect_plan(sql)?; - assert!( - explain_plan.contains("MarkAnyApply"), - "unexpected explain plan: {explain_plan}" - ); - assert!( - explain_plan.contains(&format!("IndexScan By {index_name} => Probe")), - "unexpected explain plan: {explain_plan}" - ); - Ok(()) - }; + let assert_mark_in_uses_parameterized_index = |sql: &str| -> Result<(), DatabaseError> { + let explain_plan = collect_plan(sql)?; + assert!( + explain_plan.contains("MarkAnyApply"), + "unexpected explain plan: {explain_plan}" + ); + assert!( + explain_plan.contains("IndexScan By #") && explain_plan.contains("=> Probe"), + "unexpected explain plan: {explain_plan}" + ); + Ok(()) + }; assert_mark_in_uses_parameterized_index( "explain select id from in_outer where a in (select v from in_inner where in_inner.v = in_outer.a)", - "in_inner_v_index", )?; assert_mark_in_uses_parameterized_index( "explain select id from in_outer where a not in (select v from in_inner where in_inner.v = in_outer.a)", - "in_inner_v_index", )?; assert_mark_in_uses_parameterized_index( "explain select id from in_outer where a in (select v from in_inner_nn where in_inner_nn.v = in_outer.a)", - "in_inner_nn_v_index", )?; assert_mark_in_uses_parameterized_index( "explain select id from in_outer where a not in (select v from in_inner_nn where in_inner_nn.v = in_outer.a)", - "in_inner_nn_v_index", )?; assert_eq!( @@ -1726,47 +1881,17 @@ pub(crate) mod test { vec![0, 1, 3] ); - kite_sql - .run("create table in_outer_flag(id int primary key, a int, b int)")? - .done()?; - kite_sql - .run("create table in_inner_flag(id int primary key, v int, flag int)")? - .done()?; - kite_sql - .run("create table in_inner_flag_nn(id int primary key, v int, flag int)")? - .done()?; - kite_sql - .run("create index in_inner_flag_v_index on in_inner_flag(v)")? - .done()?; - kite_sql - .run("create index in_inner_flag_nn_v_index on in_inner_flag_nn(v)")? - .done()?; - - kite_sql - .run("insert into in_outer_flag values (0, null, 1), (1, 1, 1), (2, 2, 1), (3, 3, 1)")? - .done()?; - kite_sql - .run("insert into in_inner_flag values (0, 2, 1), (1, null, 1)")? - .done()?; - kite_sql - .run("insert into in_inner_flag_nn values (0, 2, 1)")? - .done()?; - assert_mark_in_uses_parameterized_index( "explain select id from in_outer_flag where a in (select v from in_inner_flag where in_inner_flag.flag = in_outer_flag.b)", - "in_inner_flag_v_index", )?; assert_mark_in_uses_parameterized_index( "explain select id from in_outer_flag where a not in (select v from in_inner_flag where in_inner_flag.flag = in_outer_flag.b)", - "in_inner_flag_v_index", )?; assert_mark_in_uses_parameterized_index( "explain select id from in_outer_flag where a in (select v from in_inner_flag_nn where in_inner_flag_nn.flag = in_outer_flag.b)", - "in_inner_flag_nn_v_index", )?; assert_mark_in_uses_parameterized_index( "explain select id from in_outer_flag where a not in (select v from in_inner_flag_nn where in_inner_flag_nn.flag = in_outer_flag.b)", - "in_inner_flag_nn_v_index", )?; assert_eq!( @@ -1800,17 +1925,11 @@ pub(crate) mod test { #[test] fn test_subquery_explain_uses_parameterized_index_for_exists() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table exists_outer(id int primary key, a int, b int)")? - .done()?; - kite_sql - .run("create table exists_inner(id int primary key, v int, flag int)")? - .done()?; - kite_sql - .run("create index exists_inner_v_index on exists_inner(v)")? - .done()?; + kite_sql.ddl("create table exists_outer(id int primary key, a int, b int)")?; + kite_sql.ddl("create table exists_inner(id int primary key, v int, flag int)")?; + kite_sql.ddl("create index exists_inner_v_index on exists_inner(v)")?; kite_sql .run("insert into exists_outer values (0, 1, 1), (1, 1, 2), (2, 2, null), (3, 3, 1)")? @@ -1835,7 +1954,7 @@ pub(crate) mod test { let collect_ids = |sql: &str| -> Result, DatabaseError> { let mut iter = kite_sql.run(sql)?; let mut ids = Vec::new(); - while let Some(row) = iter.next() { + for row in iter.by_ref() { let row = row?; ids.push(row.values[0].i32().unwrap()); } @@ -1849,7 +1968,7 @@ pub(crate) mod test { "unexpected explain plan: {explain_plan}" ); assert!( - explain_plan.contains("IndexScan By exists_inner_v_index => Probe"), + explain_plan.contains("IndexScan By #") && explain_plan.contains("=> Probe"), "unexpected explain plan: {explain_plan}" ); Ok(()) @@ -1881,11 +2000,9 @@ pub(crate) mod test { #[test] fn test_run_multi_statement() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t_multi (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t_multi (a int primary key, b int)")?; let mut iter = kite_sql.run( "insert into t_multi values(0, 0); insert into t_multi values(1, 1); select * from t_multi order by a", @@ -1915,7 +2032,7 @@ pub(crate) mod test { }; match err { DatabaseError::UnsupportedStmt(msg) => { - assert!(msg.contains("multi-statement execution")); + assert!(msg.contains("DDL and ANALYZE")); } other => panic!("unexpected error type: {other:?}"), } @@ -1926,11 +2043,9 @@ pub(crate) mod test { #[test] fn test_bind_error_with_span() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t_bind_span(id int primary key)")? - .done()?; + kite_sql.ddl("create table t_bind_span(id int primary key)")?; let err = match kite_sql.run("select id, missing_col from t_bind_span") { Ok(_) => panic!("expected bind error"), @@ -1959,11 +2074,9 @@ pub(crate) mod test { #[test] fn test_bind_function_error_with_span() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t_bind_fn_span(id int primary key)")? - .done()?; + kite_sql.ddl("create table t_bind_fn_span(id int primary key)")?; let err = match kite_sql.run("select missing_fn(id) from t_bind_fn_span") { Ok(_) => panic!("expected function bind error"), @@ -1991,11 +2104,9 @@ pub(crate) mod test { #[test] fn test_transaction_sql() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t1 (a int primary key, b int)")?; let mut tx_1 = kite_sql.new_transaction()?; let mut tx_2 = kite_sql.new_transaction()?; @@ -2038,11 +2149,9 @@ pub(crate) mod test { #[test] fn test_transaction_run_multi_statement() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t_multi_tx (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t_multi_tx (a int primary key, b int)")?; let mut tx = kite_sql.new_transaction()?; let mut iter = tx.run( @@ -2073,11 +2182,9 @@ pub(crate) mod test { #[test] fn test_autocommit_read_drops_iterator_before_transaction() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_optimistic()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_optimistic()?; - kite_sql - .run("create table t_iter_drop (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t_iter_drop (a int primary key, b int)")?; let mut tx = kite_sql.new_transaction()?; tx.run("insert into t_iter_drop values (0, 0), (1, 1)")? @@ -2093,11 +2200,9 @@ pub(crate) mod test { #[test] fn test_exhausted_database_iter_releases_mdl_guard() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_optimistic()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_optimistic()?; - kite_sql - .run("create table t_iter_guard (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t_iter_guard (a int primary key, b int)")?; kite_sql .run("insert into t_iter_guard values (0, 0), (1, 1)")? .done()?; @@ -2112,8 +2217,9 @@ pub(crate) mod test { vec![DataValue::Int32(1), DataValue::Int32(1)] ); assert!(iter.next().is_none()); + iter.done()?; - kite_sql.run("drop table t_iter_guard")?.done()?; + kite_sql.ddl("drop table t_iter_guard")?; Ok(()) } @@ -2121,13 +2227,11 @@ pub(crate) mod test { #[test] fn test_read_committed_refreshes_snapshot_each_statement() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()) + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()) .transaction_isolation(TransactionIsolationLevel::ReadCommitted) .build_rocksdb()?; - kite_sql - .run("create table t_rc (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t_rc (a int primary key, b int)")?; kite_sql.run("insert into t_rc values (1, 10)")?.done()?; let mut reader = kite_sql.new_transaction()?; @@ -2166,13 +2270,11 @@ pub(crate) mod test { #[test] fn test_repeatable_read_keeps_transaction_snapshot() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()) + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()) .transaction_isolation(TransactionIsolationLevel::RepeatableRead) .build_rocksdb()?; - kite_sql - .run("create table t_rr (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t_rr (a int primary key, b int)")?; kite_sql.run("insert into t_rr values (1, 10)")?.done()?; let mut reader = kite_sql.new_transaction()?; @@ -2198,11 +2300,9 @@ pub(crate) mod test { #[test] fn test_optimistic_transaction_sql() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_optimistic()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_optimistic()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t1 (a int primary key, b int)")?; let mut tx_1 = kite_sql.new_transaction()?; let mut tx_2 = kite_sql.new_transaction()?; @@ -2284,11 +2384,9 @@ pub(crate) mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let live_path = temp_dir.path().join("live"); let checkpoint_path = temp_dir.path().join("checkpoint"); - let kite_sql = DataBaseBuilder::path(&live_path).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(&live_path).build_rocksdb()?; - kite_sql - .run("create table t_checkpoint (id int primary key, v int)")? - .done()?; + kite_sql.ddl("create table t_checkpoint (id int primary key, v int)")?; kite_sql .run("insert into t_checkpoint values (1, 10), (2, 20)")? .done()?; @@ -2299,7 +2397,8 @@ pub(crate) mod test { .run("insert into t_checkpoint values (3, 30)")? .done()?; - let snapshot = DataBaseBuilder::path(&checkpoint_path).build_rocksdb()?; + let mut snapshot = DataBaseBuilder::path(&checkpoint_path).build_rocksdb()?; + snapshot.load(CatalogKind::Table("t_checkpoint".to_string().into()))?; assert_eq!( query_i32(&snapshot, "select count(*) from t_checkpoint")?, 2 @@ -2339,11 +2438,9 @@ pub(crate) mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let live_path = temp_dir.path().join("live"); let checkpoint_path = temp_dir.path().join("checkpoint"); - let kite_sql = DataBaseBuilder::path(&live_path).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(&live_path).build_rocksdb()?; - kite_sql - .run("create table t_checkpoint_disabled (id int primary key, v int)")? - .done()?; + kite_sql.ddl("create table t_checkpoint_disabled (id int primary key, v int)")?; let err = kite_sql .checkpoint(&checkpoint_path) @@ -2359,11 +2456,10 @@ pub(crate) mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let live_path = temp_dir.path().join("live"); let checkpoint_path = temp_dir.path().join("checkpoint"); - let kite_sql = Arc::new(DataBaseBuilder::path(&live_path).build_rocksdb()?); + let mut kite_sql = DataBaseBuilder::path(&live_path).build_rocksdb()?; - kite_sql - .run("create table t_checkpoint_concurrent (id int primary key, v int)")? - .done()?; + kite_sql.ddl("create table t_checkpoint_concurrent (id int primary key, v int)")?; + let kite_sql = Arc::new(std::sync::Mutex::new(kite_sql)); let inserted = Arc::new(AtomicUsize::new(0)); let writer_db = Arc::clone(&kite_sql); @@ -2371,6 +2467,8 @@ pub(crate) mod test { let writer = thread::spawn(move || -> Result { for i in 0..64 { writer_db + .lock() + .unwrap() .run(format!( "insert into t_checkpoint_concurrent values ({i}, {i})" ))? @@ -2389,10 +2487,13 @@ pub(crate) mod test { thread::yield_now(); } - kite_sql.checkpoint(&checkpoint_path)?; + kite_sql.lock().unwrap().checkpoint(&checkpoint_path)?; let total = writer.join().expect("writer thread should not panic")?; - let snapshot = DataBaseBuilder::path(&checkpoint_path).build_rocksdb()?; + let mut snapshot = DataBaseBuilder::path(&checkpoint_path).build_rocksdb()?; + snapshot.load(CatalogKind::Table( + "t_checkpoint_concurrent".to_string().into(), + ))?; let snapshot_count = query_i32(&snapshot, "select count(*) from t_checkpoint_concurrent")?; let consistent_count = query_i32( &snapshot, @@ -2403,7 +2504,10 @@ pub(crate) mod test { assert!(snapshot_count <= total as i32); assert_eq!(snapshot_count, consistent_count); assert_eq!( - query_i32(&kite_sql, "select count(*) from t_checkpoint_concurrent")?, + query_i32( + &kite_sql.lock().unwrap(), + "select count(*) from t_checkpoint_concurrent", + )?, total as i32 ); diff --git a/src/errors.rs b/src/errors.rs index 837d7f8c..e4c9a0a5 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -15,9 +15,13 @@ use crate::expression::{BinaryOperator, UnaryOperator}; use crate::types::tuple::TupleId; use crate::types::LogicalType; +#[cfg(feature = "time")] use chrono::ParseError; +#[cfg(feature = "parser")] use sqlparser::parser::ParserError; use std::convert::Infallible; +use std::error::Error; +use std::fmt; use std::num::{ParseFloatError, ParseIntError, TryFromIntError}; use std::str::{ParseBoolError, Utf8Error}; use std::string::FromUtf8Error; @@ -52,221 +56,281 @@ fn format_not_null_message(column: &Option, span: &Option) } } -#[derive(thiserror::Error, Debug)] +#[derive(Debug)] pub enum DatabaseError { - #[error("agg miss: {0}")] AggMiss(String), - #[error("bindcode: {0}")] - Bincode( - #[source] - #[from] - Box, - ), - #[error("cache size overflow")] CacheSizeOverFlow, - #[error( - "cast fail: {from} -> {to}{loc}", - loc = format_sql_error_loc(span) - )] CastFail { from: LogicalType, to: LogicalType, span: Option, }, - #[error("channel close")] ChannelClose, - #[error("columns empty")] ColumnsEmpty, - #[error("column id: `{0}` not found")] ColumnIdNotFound(String), - #[error( - "column: `{name}` not found{loc}", - loc = format_sql_error_loc(span) - )] ColumnNotFound { name: String, span: Option, }, - #[error("csv error: {0}")] - Csv( - #[from] - #[source] - csv::Error, - ), - #[error("default cannot be a column related to the table")] + #[cfg(feature = "copy")] + Csv(csv::Error), DefaultNotColumnRef, - #[error("default does not exist")] DefaultNotExist, - #[error("column: `{0}` already exists")] DuplicateColumn(String), - #[error("table or view: `{0}` hash already exists")] DuplicateSourceHash(String), - #[error("index: `{0}` already exists")] DuplicateIndex(String), - #[error("duplicate primary key")] DuplicatePrimaryKey, - #[error("the column has been declared unique and the value already exists")] DuplicateUniqueValue, - #[error( - "function: `{name}` not found{loc}", - loc = format_sql_error_loc(span) - )] FunctionNotFound { name: String, span: Option, }, - #[error("empty plan")] EmptyPlan, - #[error("sql statement is empty")] EmptyStatement, - #[error("evaluator not found")] EvaluatorNotFound, - #[error("from utf8: {0}")] - FromUtf8Error( - #[source] - #[from] - FromUtf8Error, - ), - #[error("can not compare two types: {0} and {1}")] + FromUtf8Error(FromUtf8Error), Incomparable(LogicalType, LogicalType), - #[error( - "invalid column: `{name}`{loc}", - loc = format_sql_error_loc(span) - )] InvalidColumn { name: String, span: Option, }, - #[error("invalid index")] InvalidIndex, - #[error( - "invalid table: `{name}`{loc}", - loc = format_sql_error_loc(span) - )] InvalidTable { name: String, span: Option, }, - #[error("invalid type")] InvalidType, - #[error("invalid value: {0}")] InvalidValue(String), - #[error("io: {0}")] - IO( - #[source] - #[from] - std::io::Error, - ), - #[error("{0} and {1} do not match")] + IO(std::io::Error), MisMatch(&'static str, &'static str), - #[error("add column must be nullable or specify a default value")] NeedNullAbleOrDefault, - #[error( - "parameter: `{name}` not found{loc}", - loc = format_sql_error_loc(span) - )] ParametersNotFound { name: String, span: Option, }, - #[error("no transaction begin")] NoTransactionBegin, - #[error("{msg}", msg = format_not_null_message(column, span))] NotNull { column: Option, span: Option, }, - #[error("over flow")] OverFlow, - #[error("parser bool: {0}")] - ParseBool( - #[source] - #[from] - ParseBoolError, - ), - #[error("parser date: {0}")] - ParseDate( - #[source] - #[from] - ParseError, - ), - #[error("parser float: {0}")] - ParseFloat( - #[source] - #[from] - ParseFloatError, - ), - #[error("parser int: {0}")] - ParseInt( - #[source] - #[from] - ParseIntError, - ), - #[error("parser sql: {0}")] - ParserSql( - #[source] - #[from] - ParserError, - ), - #[error("must contain primary key!")] + ParseBool(ParseBoolError), + #[cfg(feature = "time")] + ParseDate(ParseError), + ParseFloat(ParseFloatError), + ParseInt(ParseIntError), + #[cfg(feature = "parser")] + ParserSql(ParserError), PrimaryKeyNotFound, - #[error("primaryKey only allows single or multiple values")] PrimaryKeyTooManyLayers, + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + Lmdb(lmdb::Error), #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] - #[error("rocksdb: {0}")] - RocksDB( - #[source] - #[from] - rocksdb::Error, - ), - #[error("the number of caches cannot be divisible by the number of shards")] + RocksDB(rocksdb::Error), SharedNotAlign, - #[error("the table or view not found")] SourceNotFound, - #[error("the table already exists")] TableExists, - #[error("the table not found")] TableNotFound, - #[error("transaction already exists")] TransactionAlreadyExists, - #[error("try from decimal: {0}")] - TryFromDecimal( - #[source] - #[from] - rust_decimal::Error, - ), - #[error("try from int: {0}")] - TryFromInt( - #[source] - #[from] - TryFromIntError, - ), - #[error("too long")] + #[cfg(feature = "decimal")] + TryFromDecimal(rust_decimal::Error), + TryFromInt(TryFromIntError), TooLong, - #[error("tuple id: {0} not found")] TupleIdNotFound(TupleId), - #[error("there are more buckets: {0} than elements: {1}")] TooManyBuckets(usize, usize), - #[error("unsupported unary operator: {0} cannot support {1} for calculations")] UnsupportedUnaryOperator(LogicalType, UnaryOperator), - #[error("unsupported binary operator: {0} cannot support {1} for calculations")] UnsupportedBinaryOperator(LogicalType, BinaryOperator), - #[error("unsupported statement: {0}")] UnsupportedStmt(String), - #[error("utf8: {0}")] - Utf8( - #[source] - #[from] - Utf8Error, - ), - #[error("values length not match, expect {0}, got {1}")] + Utf8(Utf8Error), ValuesLenMismatch(usize, usize), - #[error("the view already exists")] ViewExists, - #[error("the view not found")] ViewNotFound, } +impl fmt::Display for DatabaseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::AggMiss(value) => write!(f, "agg miss: {value}"), + Self::CacheSizeOverFlow => f.write_str("cache size overflow"), + Self::CastFail { from, to, span } => { + write!(f, "cast fail: {from} -> {to}{}", format_sql_error_loc(span)) + } + Self::ChannelClose => f.write_str("channel close"), + Self::ColumnsEmpty => f.write_str("columns empty"), + Self::ColumnIdNotFound(value) => write!(f, "column id: `{value}` not found"), + Self::ColumnNotFound { name, span } => { + write!( + f, + "column: `{name}` not found{}", + format_sql_error_loc(span) + ) + } + #[cfg(feature = "copy")] + Self::Csv(err) => write!(f, "csv error: {err}"), + Self::DefaultNotColumnRef => { + f.write_str("default cannot be a column related to the table") + } + Self::DefaultNotExist => f.write_str("default does not exist"), + Self::DuplicateColumn(value) => write!(f, "column: `{value}` already exists"), + Self::DuplicateSourceHash(value) => { + write!(f, "table or view: `{value}` hash already exists") + } + Self::DuplicateIndex(value) => write!(f, "index: `{value}` already exists"), + Self::DuplicatePrimaryKey => f.write_str("duplicate primary key"), + Self::DuplicateUniqueValue => { + f.write_str("the column has been declared unique and the value already exists") + } + Self::FunctionNotFound { name, span } => { + write!( + f, + "function: `{name}` not found{}", + format_sql_error_loc(span) + ) + } + Self::EmptyPlan => f.write_str("empty plan"), + Self::EmptyStatement => f.write_str("sql statement is empty"), + Self::EvaluatorNotFound => f.write_str("evaluator not found"), + Self::FromUtf8Error(err) => write!(f, "from utf8: {err}"), + Self::Incomparable(left, right) => { + write!(f, "can not compare two types: {left} and {right}") + } + Self::InvalidColumn { name, span } => { + write!(f, "invalid column: `{name}`{}", format_sql_error_loc(span)) + } + Self::InvalidIndex => f.write_str("invalid index"), + Self::InvalidTable { name, span } => { + write!(f, "invalid table: `{name}`{}", format_sql_error_loc(span)) + } + Self::InvalidType => f.write_str("invalid type"), + Self::InvalidValue(value) => write!(f, "invalid value: {value}"), + Self::IO(err) => write!(f, "io: {err}"), + Self::MisMatch(left, right) => write!(f, "{left} and {right} do not match"), + Self::NeedNullAbleOrDefault => { + f.write_str("add column must be nullable or specify a default value") + } + Self::ParametersNotFound { name, span } => { + write!( + f, + "parameter: `{name}` not found{}", + format_sql_error_loc(span) + ) + } + Self::NoTransactionBegin => f.write_str("no transaction begin"), + Self::NotNull { column, span } => f.write_str(&format_not_null_message(column, span)), + Self::OverFlow => f.write_str("over flow"), + Self::ParseBool(err) => write!(f, "parser bool: {err}"), + #[cfg(feature = "time")] + Self::ParseDate(err) => write!(f, "parser date: {err}"), + Self::ParseFloat(err) => write!(f, "parser float: {err}"), + Self::ParseInt(err) => write!(f, "parser int: {err}"), + #[cfg(feature = "parser")] + Self::ParserSql(err) => write!(f, "parser sql: {err}"), + Self::PrimaryKeyNotFound => f.write_str("must contain primary key!"), + Self::PrimaryKeyTooManyLayers => { + f.write_str("primaryKey only allows single or multiple values") + } + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + Self::Lmdb(err) => write!(f, "lmdb: {err}"), + #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] + Self::RocksDB(err) => write!(f, "rocksdb: {err}"), + Self::SharedNotAlign => { + f.write_str("the number of caches cannot be divisible by the number of shards") + } + Self::SourceNotFound => f.write_str("the table or view not found"), + Self::TableExists => f.write_str("the table already exists"), + Self::TableNotFound => f.write_str("the table not found"), + Self::TransactionAlreadyExists => f.write_str("transaction already exists"), + #[cfg(feature = "decimal")] + Self::TryFromDecimal(err) => write!(f, "try from decimal: {err}"), + Self::TryFromInt(err) => write!(f, "try from int: {err}"), + Self::TooLong => f.write_str("too long"), + Self::TupleIdNotFound(value) => write!(f, "tuple id: {value} not found"), + Self::TooManyBuckets(buckets, elements) => { + write!( + f, + "there are more buckets: {buckets} than elements: {elements}" + ) + } + Self::UnsupportedUnaryOperator(ty, op) => { + write!( + f, + "unsupported unary operator: {ty} cannot support {op} for calculations" + ) + } + Self::UnsupportedBinaryOperator(ty, op) => { + write!( + f, + "unsupported binary operator: {ty} cannot support {op} for calculations" + ) + } + Self::UnsupportedStmt(value) => write!(f, "unsupported statement: {value}"), + Self::Utf8(err) => write!(f, "utf8: {err}"), + Self::ValuesLenMismatch(expect, got) => { + write!(f, "values length not match, expect {expect}, got {got}") + } + Self::ViewExists => f.write_str("the view already exists"), + Self::ViewNotFound => f.write_str("the view not found"), + } + } +} + +impl Error for DatabaseError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + #[cfg(feature = "copy")] + Self::Csv(err) => Some(err), + Self::FromUtf8Error(err) => Some(err), + Self::IO(err) => Some(err), + Self::ParseBool(err) => Some(err), + #[cfg(feature = "time")] + Self::ParseDate(err) => Some(err), + Self::ParseFloat(err) => Some(err), + Self::ParseInt(err) => Some(err), + #[cfg(feature = "parser")] + Self::ParserSql(err) => Some(err), + #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] + Self::Lmdb(err) => Some(err), + #[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] + Self::RocksDB(err) => Some(err), + #[cfg(feature = "decimal")] + Self::TryFromDecimal(err) => Some(err), + Self::TryFromInt(err) => Some(err), + Self::Utf8(err) => Some(err), + _ => None, + } + } +} + +macro_rules! impl_from_database_error { + ($source:ty, $variant:ident) => { + impl From<$source> for DatabaseError { + fn from(value: $source) -> Self { + Self::$variant(value) + } + } + }; +} + +#[cfg(feature = "copy")] +impl_from_database_error!(csv::Error, Csv); +impl_from_database_error!(FromUtf8Error, FromUtf8Error); +impl_from_database_error!(std::io::Error, IO); +impl_from_database_error!(ParseBoolError, ParseBool); +#[cfg(feature = "time")] +impl_from_database_error!(ParseError, ParseDate); +impl_from_database_error!(ParseFloatError, ParseFloat); +impl_from_database_error!(ParseIntError, ParseInt); +#[cfg(feature = "parser")] +impl_from_database_error!(ParserError, ParserSql); +#[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] +impl_from_database_error!(lmdb::Error, Lmdb); +#[cfg(all(not(target_arch = "wasm32"), feature = "rocksdb"))] +impl_from_database_error!(rocksdb::Error, RocksDB); +#[cfg(feature = "decimal")] +impl_from_database_error!(rust_decimal::Error, TryFromDecimal); +impl_from_database_error!(TryFromIntError, TryFromInt); +impl_from_database_error!(Utf8Error, Utf8); + impl From for DatabaseError { fn from(value: Infallible) -> Self { match value {} diff --git a/src/execution/ddl/add_column.rs b/src/execution/ddl/add_column.rs index 59276d4f..b2721778 100644 --- a/src/execution/ddl/add_column.rs +++ b/src/execution/ddl/add_column.rs @@ -14,7 +14,9 @@ use super::rewrite_table_in_batches; use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::alter_table::add_column::AddColumnOperator; use crate::storage::Transaction; use crate::types::index::{Index, IndexType}; @@ -33,20 +35,25 @@ impl From for AddColumn { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for AddColumn { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::AddColumn(self)) + let executor = input; + arena.push(ExecNode::AddColumn(executor)) } } -impl AddColumn { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for AddColumn { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let table_cache = arena.table_cache(); let Some(AddColumnOperator { @@ -59,12 +66,18 @@ impl AddColumn { return Ok(()); }; - let table_catalog = arena - .transaction_mut() - .table(table_cache, table_name.clone())? - .cloned() - .ok_or(DatabaseError::TableNotFound)?; - if table_catalog.get_column_by_name(column.name()).is_some() { + let (old_schema, pk_ty, column_exists) = { + let table_catalog = arena + .transaction() + .table(table_cache, table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + ( + table_catalog.columns().copied().collect_vec(), + table_catalog.primary_keys_type().clone(), + table_catalog.get_column_by_name(column.name()).is_some(), + ) + }; + if column_exists { if if_not_exists { TupleBuilder::build_result_into(arena.result_tuple_mut(), "1".to_string()); arena.resume(); @@ -73,41 +86,49 @@ impl AddColumn { return Err(DatabaseError::DuplicateColumn(column.name().to_string())); } - let schema = table_catalog.schema_ref().clone(); - let old_deserializers = schema - .iter() - .map(|column_ref| column_ref.datatype().serializable()) - .collect_vec(); - let serializers = schema - .iter() - .map(|column_ref| column_ref.datatype().serializable()) - .chain(::std::iter::once(column.datatype().serializable())) - .collect_vec(); - let pk_ty = table_catalog.primary_keys_type().clone(); let default_value = column.default_value()?; - let col_id = - arena - .transaction_mut() - .add_column(table_cache, &table_name, &column, if_not_exists)?; - let unique_meta = if column.desc().is_unique() { - arena - .transaction_mut() - .table(table_cache, table_name.clone())? - .and_then(|table| table.get_unique_index(&col_id)) - .cloned() - } else { - None + let (unique_index_id, apply) = { + let (transaction, table_codec) = arena.transaction_codec_mut(); + let (table, col_id) = transaction.add_column( + table_codec, + plan_arena, + &table_name, + &column, + if_not_exists, + )?; + let unique_meta = if column.desc().is_unique() { + table + .get_unique_index(&col_id, plan_arena) + .map(|index| plan_arena.index(index).id) + } else { + None + }; + (unique_meta, DDLApply::upsert_table(table, false)) }; + arena.push_ddl_apply(apply); let default_for_index = default_value.clone(); + let mut state = arena.local_state(plan_arena); + let plan_arena = state.plan_arena; + let (transaction, table_codec) = state.transaction_codec_mut(); rewrite_table_in_batches( - arena.transaction_mut(), + transaction, + table_codec, &table_name, &pk_ty, - &old_deserializers, - schema.len(), - &serializers, + old_schema.len(), + || { + old_schema + .iter() + .map(|column| plan_arena.column(*column).datatype().serializable()) + }, + || { + old_schema + .iter() + .map(|column| plan_arena.column(*column).datatype().serializable()) + .chain(::std::iter::once(column.datatype().serializable())) + }, |tuple| { if let Some(value) = &default_value { tuple.values.push(value.clone()); @@ -116,14 +137,14 @@ impl AddColumn { } Ok(()) }, - |transaction, tuple| { - if let (Some(unique_meta), Some(value), Some(tuple_id)) = ( - unique_meta.as_ref(), + |transaction, table_codec, tuple| { + if let (Some(unique_index_id), Some(value), Some(tuple_id)) = ( + unique_index_id.as_ref(), default_for_index.as_ref(), tuple.pk.as_ref(), ) { - let index = Index::new(unique_meta.id, value, IndexType::Unique); - transaction.add_index(&table_name, index, tuple_id)?; + let index = Index::new(*unique_index_id, value, IndexType::Unique); + transaction.add_index(table_codec, &table_name, index, tuple_id)?; } Ok(()) }, diff --git a/src/execution/ddl/change_column.rs b/src/execution/ddl/change_column.rs index 66f43ced..b8f3d026 100644 --- a/src/execution/ddl/change_column.rs +++ b/src/execution/ddl/change_column.rs @@ -14,7 +14,9 @@ use super::{rewrite_table_in_batches, visit_table_in_batches}; use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::alter_table::change_column::{ChangeColumnOperator, NotNullChange}; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; @@ -31,20 +33,25 @@ impl From for ChangeColumn { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for ChangeColumn { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::ChangeColumn(self)) + let executor = input; + arena.push(ExecNode::ChangeColumn(executor)) } } -impl ChangeColumn { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ChangeColumn { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let table_cache = arena.table_cache(); let Some(ChangeColumnOperator { @@ -60,63 +67,76 @@ impl ChangeColumn { return Ok(()); }; - let table_catalog = arena - .transaction_mut() - .table(table_cache, table_name.clone())? - .cloned() - .ok_or(DatabaseError::TableNotFound)?; - let schema = table_catalog.schema_ref().clone(); - let (column_index, old_column) = schema - .iter() - .enumerate() - .find(|(_, column)| column.name() == old_column_name) - .map(|(index, column)| (index, column.clone())) - .ok_or_else(|| DatabaseError::column_not_found(old_column_name.clone()))?; - let needs_data_rewrite = old_column.datatype() != &data_type; + let (old_schema, pk_ty, column_index, old_column_type, old_column_id, affected_index_name) = { + let table_catalog = arena + .transaction() + .table(table_cache, table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + let (column_index, old_column) = table_catalog + .columns() + .enumerate() + .find_map(|(index, column)| { + let column = plan_arena.column(*column); + (column.name() == old_column_name).then_some((index, column)) + }) + .ok_or_else(|| DatabaseError::column_not_found(old_column_name.clone()))?; + let old_column_id = old_column.id(); + let affected_index_name = old_column_id.and_then(|column_id| { + table_catalog + .indexes() + .map(|index_meta| plan_arena.index(*index_meta)) + .find(|index_meta| index_meta.column_ids.contains(&column_id)) + .map(|index_meta| index_meta.name.clone()) + }); + ( + table_catalog.columns().copied().collect_vec(), + table_catalog.primary_keys_type().clone(), + column_index, + old_column.datatype().clone(), + old_column_id, + affected_index_name, + ) + }; + let needs_data_rewrite = old_column_type != data_type; let needs_not_null_validation = matches!(not_null_change, NotNullChange::Set); if needs_data_rewrite { - let Some(column_id) = old_column.id() else { + let Some(_) = old_column_id else { return Err(DatabaseError::column_not_found(old_column_name.clone())); }; - let affected_index = table_catalog - .indexes() - .find(|index_meta| index_meta.column_ids.contains(&column_id)); - if let Some(index_meta) = affected_index { + if let Some(index_name) = affected_index_name { return Err(DatabaseError::UnsupportedStmt(format!( - "cannot alter type of indexed column `{}`; drop index `{}` first", - old_column_name, index_meta.name + "cannot alter type of indexed column `{old_column_name}`; drop index `{index_name}` first" ))); } } - let old_deserializers = schema - .iter() - .map(|column| column.datatype().serializable()) - .collect_vec(); - let pk_ty = table_catalog.primary_keys_type().clone(); - if needs_data_rewrite { - let serializers = schema - .iter() - .enumerate() - .map(|(index, column)| { - if index == column_index { - data_type.serializable() - } else { - column.datatype().serializable() - } - }) - .collect_vec(); let target_column_name = new_column_name.clone(); let target_data_type = data_type.clone(); + let mut state = arena.local_state(plan_arena); + let plan_arena = state.plan_arena; + let (transaction, table_codec) = state.transaction_codec_mut(); rewrite_table_in_batches( - arena.transaction_mut(), + transaction, + table_codec, &table_name, &pk_ty, - &old_deserializers, - schema.len(), - &serializers, + old_schema.len(), + || { + old_schema + .iter() + .map(|column| plan_arena.column(*column).datatype().serializable()) + }, + || { + old_schema.iter().enumerate().map(|(index, column)| { + if index == column_index { + data_type.serializable() + } else { + plan_arena.column(*column).datatype().serializable() + } + }) + }, |tuple| { tuple.values[column_index] = tuple.values[column_index].clone().cast(&target_data_type)?; @@ -125,16 +145,24 @@ impl ChangeColumn { } Ok(()) }, - |_, _| Ok(()), + |_, _, _| Ok(()), )?; } else if needs_not_null_validation { let target_column_name = new_column_name.clone(); + let mut state = arena.local_state(plan_arena); + let plan_arena = state.plan_arena; + let (transaction, table_codec) = state.transaction_codec_mut(); visit_table_in_batches( - arena.transaction(), + transaction, + table_codec, &table_name, &pk_ty, - &old_deserializers, - schema.len(), + old_schema.len(), + || { + old_schema + .iter() + .map(|column| plan_arena.column(*column).datatype().serializable()) + }, |tuple| { if tuple.values[column_index].is_null() { return Err(DatabaseError::not_null_column(target_column_name.clone())); @@ -144,15 +172,21 @@ impl ChangeColumn { )?; } - arena.transaction_mut().change_column( - table_cache, - &table_name, - &old_column_name, - &new_column_name, - &data_type, - &default_change, - ¬_null_change, - )?; + let apply = { + let (transaction, table_codec) = arena.transaction_codec_mut(); + let table = transaction.change_column( + table_codec, + plan_arena, + &table_name, + &old_column_name, + &new_column_name, + &data_type, + &default_change, + ¬_null_change, + )?; + DDLApply::upsert_table(table, true) + }; + arena.push_ddl_apply(apply); TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{table_name}")); arena.resume(); diff --git a/src/execution/ddl/create_index.rs b/src/execution/ddl/create_index.rs index 525583fb..340d0b13 100644 --- a/src/execution/ddl/create_index.rs +++ b/src/execution/ddl/create_index.rs @@ -13,30 +13,31 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + build_read, with_projection_tmp_value, DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, + ExecutorNode, WriteExecutor, +}; use crate::expression::ScalarExpression; use crate::planner::operator::create_index::CreateIndexOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::index::Index; -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use crate::types::tuple_builder::TupleBuilder; -use crate::types::value::DataValue; use crate::types::ColumnId; pub struct CreateIndex { op: Option, - input_schema: SchemaRef, + input_schema: Schema, input_plan: LogicalPlan, input: ExecId, } impl From<(CreateIndexOperator, LogicalPlan)> for CreateIndex { - fn from((op, mut input): (CreateIndexOperator, LogicalPlan)) -> Self { + fn from((op, input): (CreateIndexOperator, LogicalPlan)) -> Self { Self { op: Some(op), - input_schema: input.output_schema().clone(), + input_schema: Default::default(), input_plan: input, input: 0, } @@ -44,24 +45,34 @@ impl From<(CreateIndexOperator, LogicalPlan)> for CreateIndex { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CreateIndex { + type Input = Self; + fn into_executor( - mut self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - self.input = build_read(arena, self.input_plan.take(), cache, transaction); - arena.push(ExecNode::CreateIndex(self)) + let mut executor = input; + executor.input_schema = executor.input_plan.take_schema(plan_arena); + executor.input = build_read( + arena, + plan_arena, + executor.input_plan.take(), + cache, + transaction, + ); + arena.push(ExecNode::CreateIndex(executor)) } } -impl CreateIndex { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for CreateIndex { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { - let table_cache = arena.table_cache(); - let Some(CreateIndexOperator { table_name, index_name, @@ -74,10 +85,21 @@ impl CreateIndex { return Ok(()); }; + if if_not_exists + && arena.table_cache().get(&table_name).is_some_and(|table| { + table + .indexes() + .any(|index| plan_arena.index(*index).name == index_name) + }) + { + arena.finish(); + return Ok(()); + } + let (column_ids, column_exprs): (Vec, Vec) = columns .into_iter() .filter_map(|column| { - column.id().and_then(|id| { + plan_arena.column(column).id().and_then(|id| { self.input_schema .iter() .position(|schema_column| schema_column == &column) @@ -85,13 +107,20 @@ impl CreateIndex { }) }) .unzip(); - let index_id = match arena.transaction_mut().add_index_meta( - table_cache, - &table_name, - index_name, - column_ids, - ty, - ) { + let index_id_result = { + let (transaction, table_codec) = arena.transaction_codec_mut(); + let (table, index_id) = transaction.add_index_meta( + table_codec, + plan_arena, + &table_name, + index_name, + column_ids, + ty, + )?; + arena.push_ddl_apply(DDLApply::upsert_table(table, false)); + Ok(index_id) + }; + let index_id = match index_id_result { Ok(index_id) => index_id, Err(DatabaseError::DuplicateIndex(index_name)) => { if if_not_exists { @@ -104,23 +133,16 @@ impl CreateIndex { Err(err) => return Err(err), }; - while arena.next_tuple(self.input)? { - let (value, tuple_pk) = { - let tuple = arena.result_tuple(); - let Some(value) = - DataValue::values_to_tuple(Projection::projection(tuple, &column_exprs)?) - else { - continue; - }; - let Some(tuple_pk) = tuple.pk.clone() else { - continue; - }; - (value, tuple_pk) + while arena.next_tuple(self.input, plan_arena)? { + let Some(tuple_pk) = arena.result_tuple().pk.clone() else { + continue; }; - let index = Index::new(index_id, &value, ty); - arena - .transaction_mut() - .add_index(table_name.as_ref(), index, &tuple_pk)?; + with_projection_tmp_value(arena, None, &column_exprs, |arena, value| { + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + let index = Index::new(index_id, value, ty); + transaction.add_index(table_codec, table_name.as_ref(), index, &tuple_pk) + })?; } TupleBuilder::build_result_into(arena.result_tuple_mut(), "1".to_string()); diff --git a/src/execution/ddl/create_table.rs b/src/execution/ddl/create_table.rs index 8ee40b20..fc275e7f 100644 --- a/src/execution/ddl/create_table.rs +++ b/src/execution/ddl/create_table.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::create_table::CreateTableOperator; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; @@ -29,20 +31,25 @@ impl From for CreateTable { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CreateTable { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::CreateTable(self)) + let executor = input; + arena.push(ExecNode::CreateTable(executor)) } } -impl CreateTable { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for CreateTable { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(CreateTableOperator { table_name, @@ -54,12 +61,17 @@ impl CreateTable { return Ok(()); }; - arena.transaction_mut().create_table( - arena.table_cache(), + let (transaction, table_codec) = arena.transaction_codec_mut(); + let table = transaction.create_table( + table_codec, + plan_arena, table_name.clone(), columns, if_not_exists, )?; + if let Some(table) = table { + arena.push_ddl_apply(DDLApply::upsert_table(table, false)); + } TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{table_name}")); arena.resume(); diff --git a/src/execution/ddl/create_view.rs b/src/execution/ddl/create_view.rs index a199667b..c676c83c 100644 --- a/src/execution/ddl/create_view.rs +++ b/src/execution/ddl/create_view.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::create_view::CreateViewOperator; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; @@ -29,29 +31,34 @@ impl From for CreateView { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CreateView { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::CreateView(self)) + let executor = input; + arena.push(ExecNode::CreateView(executor)) } } -impl CreateView { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for CreateView { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(CreateViewOperator { view, or_replace }) = self.op.take() else { arena.finish(); return Ok(()); }; let view_name = view.name.to_string(); - arena - .transaction_mut() - .create_view(arena.view_cache(), view, or_replace)?; + let (transaction, table_codec) = arena.transaction_codec_mut(); + let view = transaction.create_view(table_codec, plan_arena, view, or_replace)?; + arena.push_ddl_apply(DDLApply::upsert_view(view)); TupleBuilder::build_result_into(arena.result_tuple_mut(), view_name); arena.resume(); diff --git a/src/execution/ddl/drop_column.rs b/src/execution/ddl/drop_column.rs index b03aa858..a94669a2 100644 --- a/src/execution/ddl/drop_column.rs +++ b/src/execution/ddl/drop_column.rs @@ -14,7 +14,9 @@ use super::rewrite_table_in_batches; use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::alter_table::drop_column::DropColumnOperator; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; @@ -31,23 +33,27 @@ impl From for DropColumn { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropColumn { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::DropColumn(self)) + let executor = input; + arena.push(ExecNode::DropColumn(executor)) } } -impl DropColumn { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for DropColumn { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let table_cache = arena.table_cache(); - let meta_cache = arena.meta_cache(); let Some(DropColumnOperator { table_name, column_name, @@ -58,53 +64,65 @@ impl DropColumn { return Ok(()); }; - let table_catalog = arena - .transaction_mut() - .table(table_cache, table_name.clone())? - .cloned() - .ok_or(DatabaseError::TableNotFound)?; - let tuple_columns = table_catalog.schema_ref().clone(); - if let Some((column_index, is_primary)) = tuple_columns - .iter() - .enumerate() - .find(|(_, column)| column.name() == column_name) - .map(|(i, column)| (i, column.desc().is_primary())) - { + let (old_schema, pk_ty, column_info) = { + let table_catalog = arena + .transaction() + .table(table_cache, table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + let column_info = table_catalog + .columns() + .enumerate() + .find_map(|(index, column)| { + let column = plan_arena.column(*column); + (column.name() == column_name).then(|| (index, column.desc().is_primary())) + }); + ( + table_catalog.columns().copied().collect_vec(), + table_catalog.primary_keys_type().clone(), + column_info, + ) + }; + if let Some((column_index, is_primary)) = column_info { if is_primary { return Err(DatabaseError::invalid_column( "drop of primary key column is not allowed.".to_owned(), )); } - let old_deserializers = tuple_columns - .iter() - .map(|column| column.datatype().serializable()) - .collect_vec(); - let serializers = tuple_columns - .iter() - .enumerate() - .filter(|(i, _)| *i != column_index) - .map(|(_, column)| column.datatype().serializable()) - .collect_vec(); - let pk_ty = table_catalog.primary_keys_type().clone(); - rewrite_table_in_batches( - arena.transaction_mut(), - &table_name, - &pk_ty, - &old_deserializers, - tuple_columns.len(), - &serializers, - |tuple| { - let _ = tuple.values.remove(column_index); - Ok(()) - }, - |_, _| Ok(()), - )?; - arena.transaction_mut().drop_column( - table_cache, - meta_cache, - &table_name, - &column_name, - )?; + { + let mut state = arena.local_state(plan_arena); + let plan_arena = state.plan_arena; + let (transaction, table_codec) = state.transaction_codec_mut(); + rewrite_table_in_batches( + transaction, + table_codec, + &table_name, + &pk_ty, + old_schema.len(), + || { + old_schema + .iter() + .map(|column| plan_arena.column(*column).datatype().serializable()) + }, + || { + old_schema + .iter() + .enumerate() + .filter(|(index, _)| *index != column_index) + .map(|(_, column)| plan_arena.column(*column).datatype().serializable()) + }, + |tuple| { + let _ = tuple.values.remove(column_index); + Ok(()) + }, + |_, _, _| Ok(()), + )?; + } + { + let (transaction, table_codec) = arena.transaction_codec_mut(); + let table = + transaction.drop_column(table_codec, plan_arena, &table_name, &column_name)?; + arena.push_ddl_apply(DDLApply::upsert_table(table, true)); + } TupleBuilder::build_result_into(arena.result_tuple_mut(), "1".to_string()); arena.resume(); diff --git a/src/execution/ddl/drop_index.rs b/src/execution/ddl/drop_index.rs index 04b31f8c..e84062bd 100644 --- a/src/execution/ddl/drop_index.rs +++ b/src/execution/ddl/drop_index.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::drop_index::DropIndexOperator; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; @@ -29,20 +31,25 @@ impl From for DropIndex { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropIndex { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::DropIndex(self)) + let executor = input; + arena.push(ExecNode::DropIndex(executor)) } } -impl DropIndex { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for DropIndex { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(DropIndexOperator { table_name, @@ -54,13 +61,23 @@ impl DropIndex { return Ok(()); }; - arena.transaction_mut().drop_index( - arena.table_cache(), - arena.meta_cache(), - table_name, - &index_name, - if_exists, - )?; + let dropped = { + let (transaction, table_codec) = arena.transaction_codec_mut(); + transaction.drop_index( + table_codec, + plan_arena, + table_name.clone(), + &index_name, + if_exists, + )? + }; + if let Some((table, index_id)) = dropped { + arena.push_ddl_apply(DDLApply::upsert_table(table, false)); + arena.push_ddl_apply(DDLApply::RemoveStatisticsMeta { + table_name: table_name.clone(), + index_id, + }); + } TupleBuilder::build_result_into(arena.result_tuple_mut(), index_name.to_string()); arena.resume(); diff --git a/src/execution/ddl/drop_table.rs b/src/execution/ddl/drop_table.rs index 4597d94a..a802f17b 100644 --- a/src/execution/ddl/drop_table.rs +++ b/src/execution/ddl/drop_table.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::drop_table::DropTableOperator; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; @@ -29,20 +31,25 @@ impl From for DropTable { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropTable { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::DropTable(self)) + let executor = input; + arena.push(ExecNode::DropTable(executor)) } } -impl DropTable { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for DropTable { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + _: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(DropTableOperator { table_name, @@ -53,9 +60,12 @@ impl DropTable { return Ok(()); }; - arena - .transaction_mut() - .drop_table(arena.table_cache(), table_name.clone(), if_exists)?; + let (transaction, table_codec) = arena.transaction_codec_mut(); + if transaction.drop_table(table_codec, table_name.clone(), if_exists)? { + arena.push_ddl_apply(DDLApply::DropTable { + name: table_name.clone(), + }); + } TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{table_name}")); arena.resume(); diff --git a/src/execution/ddl/drop_view.rs b/src/execution/ddl/drop_view.rs index 04f80fdd..5f34f2ce 100644 --- a/src/execution/ddl/drop_view.rs +++ b/src/execution/ddl/drop_view.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::drop_view::DropViewOperator; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; @@ -29,20 +31,25 @@ impl From for DropView { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropView { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::DropView(self)) + let executor = input; + arena.push(ExecNode::DropView(executor)) } } -impl DropView { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for DropView { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + _: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(DropViewOperator { view_name, @@ -53,10 +60,12 @@ impl DropView { return Ok(()); }; - let view_cache = arena.view_cache(); - arena - .transaction_mut() - .drop_view(view_cache, view_name.clone(), if_exists)?; + let (transaction, table_codec) = arena.transaction_codec_mut(); + if transaction.drop_view(table_codec, view_name.clone(), if_exists)? { + arena.push_ddl_apply(DDLApply::DropView { + name: view_name.clone(), + }); + } TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{view_name}")); arena.resume(); diff --git a/src/execution/ddl/mod.rs b/src/execution/ddl/mod.rs index f2423d8b..e0141708 100644 --- a/src/execution/ddl/mod.rs +++ b/src/execution/ddl/mod.rs @@ -35,19 +35,24 @@ use std::collections::Bound; const REWRITE_BATCH_SIZE: usize = 1024; #[allow(clippy::too_many_arguments)] -fn read_tuple_batch( +fn read_tuple_batch( transaction: &T, + table_codec: &mut TableCodec, table_name: &TableName, pk_ty: &LogicalType, - old_deserializers: &[TupleValueSerializableImpl], old_values_len: usize, + old_deserializers: &mut D, start_after: Option<&TupleId>, batch: &mut Vec, batch_size: usize, -) -> Result { - let table_codec = unsafe { &*transaction.table_codec() }; +) -> Result +where + T: Transaction, + D: FnMut() -> I, + I: IntoIterator, +{ let lower = if let Some(last_pk) = start_after { - table_codec.with_tuple_key_unchecked(table_name.as_ref(), last_pk, |key| { + table_codec.with_tuple_unchecked(table_name.as_ref(), last_pk, None, |key, _| { Ok::<_, DatabaseError>(Bound::Excluded(key.to_vec())) })? } else { @@ -81,7 +86,7 @@ fn read_tuple_batch( }; TableCodec::decode_tuple_into( tuple, - old_deserializers, + old_deserializers(), Some(tuple_id), value, old_values_len, @@ -93,16 +98,19 @@ fn read_tuple_batch( }) } -pub(crate) fn visit_table_in_batches( +pub(crate) fn visit_table_in_batches( transaction: &T, + table_codec: &mut TableCodec, table_name: &TableName, pk_ty: &LogicalType, - old_deserializers: &[TupleValueSerializableImpl], old_values_len: usize, + mut old_deserializers: D, mut visit: F, ) -> Result<(), DatabaseError> where T: Transaction, + D: FnMut() -> I, + I: IntoIterator, F: FnMut(&Tuple) -> Result<(), DatabaseError>, { let mut last_pk = None; @@ -111,10 +119,11 @@ where loop { let batch_len = read_tuple_batch( transaction, + table_codec, table_name, pk_ty, - old_deserializers, old_values_len, + &mut old_deserializers, last_pk.as_ref(), &mut batch, REWRITE_BATCH_SIZE, @@ -137,20 +146,25 @@ where } #[allow(clippy::too_many_arguments)] -pub(crate) fn rewrite_table_in_batches( +pub(crate) fn rewrite_table_in_batches( transaction: &mut T, + table_codec: &mut TableCodec, table_name: &TableName, pk_ty: &LogicalType, - old_deserializers: &[TupleValueSerializableImpl], old_values_len: usize, - new_serializers: &[TupleValueSerializableImpl], + mut old_deserializers: D, + mut new_serializers: S, mut rewrite: F, mut after_write: G, ) -> Result<(), DatabaseError> where T: Transaction, + D: FnMut() -> I, + I: IntoIterator, + S: FnMut() -> J, + J: IntoIterator, F: FnMut(&mut Tuple) -> Result<(), DatabaseError>, - G: FnMut(&mut T, &Tuple) -> Result<(), DatabaseError>, + G: FnMut(&mut T, &mut TableCodec, &Tuple) -> Result<(), DatabaseError>, { let mut last_pk = None; let mut batch = Vec::with_capacity(REWRITE_BATCH_SIZE); @@ -158,10 +172,11 @@ where loop { let batch_len = read_tuple_batch( transaction, + table_codec, table_name, pk_ty, - old_deserializers, old_values_len, + &mut old_deserializers, last_pk.as_ref(), &mut batch, REWRITE_BATCH_SIZE, @@ -173,8 +188,14 @@ where for tuple in batch.iter_mut().take(batch_len) { rewrite(tuple)?; - transaction.append_tuple(table_name.as_ref(), tuple.clone(), new_serializers, true)?; - after_write(transaction, tuple)?; + transaction.append_tuple( + table_codec, + table_name.as_ref(), + tuple.clone(), + new_serializers(), + true, + )?; + after_write(transaction, table_codec, tuple)?; } if batch_len < REWRITE_BATCH_SIZE { diff --git a/src/execution/ddl/truncate.rs b/src/execution/ddl/truncate.rs index ac5ac4d3..49277bba 100644 --- a/src/execution/ddl/truncate.rs +++ b/src/execution/ddl/truncate.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::truncate::TruncateOperator; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; @@ -29,26 +31,33 @@ impl From for Truncate { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Truncate { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::Truncate(self)) + let executor = input; + arena.push(ExecNode::Truncate(executor)) } } -impl Truncate { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Truncate { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(TruncateOperator { table_name }) = self.op.take() else { arena.finish(); return Ok(()); }; - arena.transaction_mut().drop_data(&table_name)?; + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + transaction.drop_data(table_codec, &table_name)?; TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{table_name}")); arena.resume(); diff --git a/src/execution/ddl_apply.rs b/src/execution/ddl_apply.rs new file mode 100644 index 00000000..7809a3e9 --- /dev/null +++ b/src/execution/ddl_apply.rs @@ -0,0 +1,56 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::catalog::view::View; +use crate::catalog::{TableCatalog, TableName}; +use crate::optimizer::core::statistics_meta::StatisticsMeta; +use crate::types::index::IndexId; + +pub(crate) enum DDLApply { + UpsertTable { + table: TableCatalog, + clear_statistics: bool, + }, + DropTable { + name: TableName, + }, + UpsertView { + view: View, + }, + DropView { + name: TableName, + }, + UpsertStatisticsMeta { + table_name: TableName, + index_id: IndexId, + meta: StatisticsMeta, + }, + RemoveStatisticsMeta { + table_name: TableName, + index_id: IndexId, + }, +} + +impl DDLApply { + pub(crate) fn upsert_table(table: TableCatalog, clear_statistics: bool) -> Self { + Self::UpsertTable { + table, + clear_statistics, + } + } + + pub(crate) fn upsert_view(view: View) -> Self { + Self::UpsertView { view } + } +} diff --git a/src/execution/dml/analyze.rs b/src/execution/dml/analyze.rs index b512f746..b30f2921 100644 --- a/src/execution/dml/analyze.rs +++ b/src/execution/dml/analyze.rs @@ -14,14 +14,16 @@ use crate::catalog::TableName; use crate::errors::DatabaseError; -use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + build_read, with_projection_tmp_value, DDLApply, ExecArena, ExecId, ExecNode, ExecutionContext, + ExecutorNode, WriteExecutor, +}; use crate::expression::ScalarExpression; use crate::optimizer::core::histogram::HistogramBuilder; use crate::optimizer::core::statistics_meta::StatisticsMeta; use crate::planner::operator::analyze::AnalyzeOperator; use crate::planner::LogicalPlan; -use crate::storage::{StatisticsMetaCache, Transaction}; +use crate::storage::{table_codec::TableCodec, Transaction}; use crate::types::index::IndexId; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; @@ -59,65 +61,70 @@ impl From<(AnalyzeOperator, LogicalPlan)> for Analyze { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Analyze { + type Input = Self; + fn into_executor( - mut self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - self.input = Some(build_read( + let mut executor = input; + executor.input = Some(build_read( arena, - self.input_plan.take(), + plan_arena, + executor.input_plan.take(), cache, transaction, )); - arena.push(ExecNode::Analyze(self)) + arena.push(ExecNode::Analyze(executor)) } } -impl Analyze { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Analyze { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(input) = self.input.take() else { arena.finish(); return Ok(()); }; - let mut builders = Vec::new(); - let table = arena - .transaction_mut() - .table(arena.table_cache(), self.table_name.clone())? - .cloned() - .ok_or(DatabaseError::TableNotFound)?; - - for index in table.indexes() { - builders.push(State { - index_id: index.id, - exprs: index.column_exprs(&table)?, - builder: HistogramBuilder::new(index, None), - histogram_buckets: self.histogram_buckets, - }); - } + let mut builders = { + let table = arena + .transaction() + .table(arena.table_cache(), self.table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + table + .indexes() + .map(|index| { + let index = plan_arena.index(*index); + Ok(State { + index_id: index.id, + exprs: index.column_exprs(table, plan_arena)?, + builder: HistogramBuilder::new(index, None), + histogram_buckets: self.histogram_buckets, + }) + }) + .collect::, DatabaseError>>()? + }; - while arena.next_tuple(input)? { - let tuple = arena.result_tuple(); + while arena.next_tuple(input, plan_arena)? { for State { exprs, builder, .. } in builders.iter_mut() { - let values = Projection::projection(tuple, exprs)?; - - if values.len() == 1 { - builder.append(&values[0])?; - } else { - builder.append(&DataValue::Tuple(values, false))?; - } + with_projection_tmp_value(arena, None, exprs, |_, value| builder.append(value))?; } } + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec, ddl_apply) = state.write_transaction_codec_ddl_apply_mut(); let values = Self::persist_statistics_meta( &self.table_name, builders, - arena.meta_cache(), - arena.transaction_mut(), + ddl_apply, + transaction, + table_codec, )?; let output = arena.result_tuple_mut(); @@ -139,8 +146,9 @@ impl Analyze { fn persist_statistics_meta( table_name: &TableName, builders: Vec, - cache: &StatisticsMetaCache, + applies: &mut Vec, transaction: &mut U, + table_codec: &mut TableCodec, ) -> Result, DatabaseError> { let mut values = Vec::with_capacity(builders.len()); @@ -155,7 +163,12 @@ impl Analyze { builder.build(histogram_buckets.unwrap_or(DEFAULT_NUM_OF_BUCKETS))?; let meta = StatisticsMeta::new(histogram, sketch); - transaction.save_statistics_meta(cache, table_name, meta)?; + transaction.save_statistics_meta(table_codec, table_name, meta.clone())?; + applies.push(DDLApply::UpsertStatisticsMeta { + table_name: table_name.clone(), + index_id, + meta, + }); values.push(DataValue::Utf8 { value: format!("{table_name}/{index_id}"), ty: Utf8Type::Variable(None), @@ -169,7 +182,7 @@ impl Analyze { impl fmt::Display for AnalyzeOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let indexes = self.index_metas.iter().map(|index| &index.name).join(", "); + let indexes = self.index_metas.iter().join(", "); write!(f, "Analyze {} -> [{}]", self.table_name, indexes)?; @@ -179,12 +192,13 @@ impl fmt::Display for AnalyzeOperator { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use crate::db::DataBaseBuilder; + use crate::db::{DataBaseBuilder, Database}; use crate::errors::DatabaseError; use crate::execution::dml::analyze::DEFAULT_NUM_OF_BUCKETS; use crate::expression::range_detacher::Range; use crate::optimizer::core::cm_sketch::COUNT_MIN_SKETCH_STORAGE_PAGE_LEN; - use crate::storage::{InnerIter, Storage, Transaction}; + use crate::optimizer::core::statistics_meta::StatisticMetaLoader; + use crate::storage::{table_codec::TableCodec, InnerIter, Storage, Transaction}; use crate::types::value::DataValue; use std::ops::Bound; use tempfile::TempDir; @@ -199,29 +213,40 @@ mod test { Ok(()) } + fn create_table( + database: &mut Database, + sql: &str, + ) -> Result<(), DatabaseError> { + database.ddl(sql) + } + + fn create_index( + database: &mut Database, + sql: &str, + ) -> Result<(), DatabaseError> { + database.ddl(sql) + } + fn test_statistics_meta_roundtrip() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let buckets = 10; - let kite_sql = DataBaseBuilder::path(temp_dir.path()) + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()) .histogram_buckets(buckets) .build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; - kite_sql.run("create index b_index on t1 (b)")?.done()?; - kite_sql.run("create index p_index on t1 (a, b)")?.done()?; + create_table(&mut kite_sql, "create table t1 (a int primary key, b int)")?; + create_index(&mut kite_sql, "create index b_index on t1 (b)")?; + create_index(&mut kite_sql, "create index p_index on t1 (a, b)")?; for i in 0..DEFAULT_NUM_OF_BUCKETS + 1 { kite_sql .run(format!("insert into t1 values({i}, {})", i % 20))? .done()?; } - kite_sql.run("analyze table t1")?.done()?; + kite_sql.analyze("t1")?; - let transaction = kite_sql.storage.transaction()?; let table_name = "t1".to_string().into(); - let loader = transaction.meta_loader(kite_sql.state.meta_cache()); + let loader = StatisticMetaLoader::new(kite_sql.state.meta_cache()); let statistics_meta_pk_index = loader.load(&table_name, 0)?.unwrap(); assert_eq!(statistics_meta_pk_index.index_id(), 0); @@ -243,46 +268,40 @@ mod test { fn test_meta_loader_uses_cache() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; - kite_sql.run("create index b_index on t1 (b)")?.done()?; + create_table(&mut kite_sql, "create table t1 (a int primary key, b int)")?; + create_index(&mut kite_sql, "create index b_index on t1 (b)")?; for i in 0..DEFAULT_NUM_OF_BUCKETS + 1 { kite_sql .run(format!("insert into t1 values({i}, {i})"))? .done()?; } - kite_sql.run("analyze table t1")?.done()?; + kite_sql.analyze("t1")?; let table_name = "t1".to_string().into(); - let transaction = kite_sql.storage.transaction()?; - let loader = transaction.meta_loader(kite_sql.state.meta_cache()); + let loader = StatisticMetaLoader::new(kite_sql.state.meta_cache()); assert!(loader.load(&table_name, 1)?.is_some()); assert_eq!( loader.collect_count(&table_name, 1, &Range::Eq(DataValue::Int32(7)))?, Some(1) ); - drop(transaction); - let mut transaction = kite_sql.storage.transaction()?; - let keys: Vec> = unsafe { &*transaction.table_codec() } - .with_statistics_index_bound("t1", 1, |min, max| { - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; - let mut keys = Vec::new(); - while let Some((key, _)) = iter.try_next()? { - keys.push(key.to_vec()); - } - Ok(keys) - })?; + let mut table_codec = TableCodec::default(); + let keys: Vec> = table_codec.with_statistics_index_bound("t1", 1, |min, max| { + let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut keys = Vec::new(); + while let Some((key, _)) = iter.try_next()? { + keys.push(key.to_vec()); + } + Ok(keys) + })?; for key in keys { transaction.remove(&key)?; } - let transaction = kite_sql.storage.transaction()?; - let loader = transaction.meta_loader(kite_sql.state.meta_cache()); + let loader = StatisticMetaLoader::new(kite_sql.state.meta_cache()); assert!(loader.load(&table_name, 1)?.is_some()); assert_eq!( loader.collect_count(&table_name, 1, &Range::Eq(DataValue::Int32(7)))?, @@ -294,72 +313,64 @@ mod test { fn test_meta_loader_negative_cache() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; - kite_sql.run("create index b_index on t1 (b)")?.done()?; + create_table(&mut kite_sql, "create table t1 (a int primary key, b int)")?; + create_index(&mut kite_sql, "create index b_index on t1 (b)")?; let table_name = "t1".to_string().into(); - let transaction = kite_sql.storage.transaction()?; - let loader = transaction.meta_loader(kite_sql.state.meta_cache()); + let loader = StatisticMetaLoader::new(kite_sql.state.meta_cache()); assert!(loader.load(&table_name, 1)?.is_none()); - let entry = kite_sql - .state - .meta_cache() - .get(&(table_name.clone(), 1)) - .expect("missing statistics cache entry"); - assert!(entry.is_none()); + assert!(!kite_sql.state.meta_cache().contains_key(&(table_name, 1))); Ok(()) } fn test_clean_expired_index() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; - kite_sql.run("create index b_index on t1 (b)")?.done()?; - kite_sql.run("create index p_index on t1 (a, b)")?.done()?; + create_table(&mut kite_sql, "create table t1 (a int primary key, b int)")?; + create_index(&mut kite_sql, "create index b_index on t1 (b)")?; + create_index(&mut kite_sql, "create index p_index on t1 (a, b)")?; for i in 0..DEFAULT_NUM_OF_BUCKETS + 1 { kite_sql .run(format!("insert into t1 values({i}, {i})"))? .done()?; } - kite_sql.run("analyze table t1")?.done()?; + kite_sql.analyze("t1")?; - let transaction = kite_sql.storage.transaction()?; - let count = - unsafe { &*transaction.table_codec() }.with_statistics_bound("t1", |min, max| { + let count = { + let transaction = kite_sql.storage.transaction()?; + let mut table_codec = TableCodec::default(); + table_codec.with_statistics_bound("t1", |min, max| { let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; let mut count = 0; while iter.try_next()?.is_some() { count += 1; } Ok(count) - })?; + })? + }; assert!(count > 3); - kite_sql.run("alter table t1 drop column b")?.done()?; - kite_sql.run("analyze table t1")?.done()?; + kite_sql.ddl("alter table t1 drop column b")?; + kite_sql.analyze("t1")?; let transaction = kite_sql.storage.transaction()?; - let keys = - unsafe { &*transaction.table_codec() }.with_statistics_bound("t1", |min, max| { - let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; - let mut keys = 0; - while iter.try_next()?.is_some() { - keys += 1; - } - Ok(keys) - })?; + let mut table_codec = TableCodec::default(); + let keys = table_codec.with_statistics_bound("t1", |min, max| { + let mut iter = transaction.range(Bound::Included(min), Bound::Included(max))?; + let mut keys = 0; + while iter.try_next()?.is_some() { + keys += 1; + } + Ok(keys) + })?; let table_name = "t1".to_string().into(); - let loader = transaction.meta_loader(kite_sql.state.meta_cache()); + let loader = StatisticMetaLoader::new(kite_sql.state.meta_cache()); let statistics_meta = loader.load(&table_name, 0)?.unwrap(); let expected_keys = 1 + 1 diff --git a/src/execution/dml/copy_from_file.rs b/src/execution/dml/copy_from_file.rs index 7a478ceb..76c4cde4 100644 --- a/src/execution/dml/copy_from_file.rs +++ b/src/execution/dml/copy_from_file.rs @@ -13,12 +13,12 @@ // limitations under the License. use crate::binder::copy::FileFormat; -use crate::catalog::PrimaryKeyIndices; use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, WriteExecutor, +}; use crate::planner::operator::copy_from_file::CopyFromFileOperator; use crate::storage::Transaction; -use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; use itertools::Itertools; use std::fs::File; @@ -35,35 +35,45 @@ impl From for CopyFromFile { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CopyFromFile { + type Input = Self; + fn into_executor( - self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::CopyFromFile(self)) + let executor = input; + arena.push(ExecNode::CopyFromFile(executor)) } } -impl CopyFromFile { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for CopyFromFile { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(op) = self.op.take() else { arena.finish(); return Ok(()); }; - let serializers = op + let column_types = op .schema_ref .iter() - .map(|column| column.datatype().serializable()) + .map(|column| plan_arena.column(*column).datatype().clone()) + .collect_vec(); + let serializers = column_types + .iter() + .map(|ty| ty.serializable()) .collect_vec(); - let table = arena - .transaction_mut() - .table(arena.table_cache(), op.table.clone())? + let table_cache = arena.context().table_cache(); + let transaction = arena.transaction(); + let table = transaction + .table(table_cache, op.table.clone())? .ok_or(DatabaseError::TableNotFound)?; - let primary_keys_indices = table.primary_keys_indices().clone(); + let table_name = table.name().to_string(); let file = File::open(op.source.path)?; let mut buf_reader = BufReader::new(file); @@ -82,7 +92,7 @@ impl CopyFromFile { }; let column_count = op.schema_ref.len(); - let tuple_builder = TupleBuilder::new(&op.schema_ref, Some(&primary_keys_indices)); + let tuple_builder = TupleBuilder::new(column_types, Some(table.primary_key_indices())); let mut size = 0_usize; for record in reader.records() { @@ -95,9 +105,9 @@ impl CopyFromFile { } let chunk = tuple_builder.build_with_row(record.iter())?; - arena - .transaction_mut() - .append_tuple(table.name(), chunk, &serializers, false)?; + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + transaction.append_tuple(table_codec, &table_name, chunk, &serializers, false)?; size += 1; } @@ -105,63 +115,19 @@ impl CopyFromFile { arena.resume(); Ok(()) } - - #[allow(dead_code)] - fn read_file_blocking( - mut self, - tx: std::sync::mpsc::Sender, - pk_indices: PrimaryKeyIndices, - ) -> Result<(), DatabaseError> { - let Some(op) = self.op.take() else { - return Ok(()); - }; - let file = File::open(op.source.path)?; - let mut buf_reader = BufReader::new(file); - let mut reader = match op.source.format { - FileFormat::Csv { - delimiter, - quote, - escape, - header, - } => csv::ReaderBuilder::new() - .delimiter(delimiter as u8) - .quote(quote as u8) - .escape(escape.map(|c| c as u8)) - .has_headers(header) - .from_reader(&mut buf_reader), - }; - - let column_count = op.schema_ref.len(); - let tuple_builder = TupleBuilder::new(&op.schema_ref, Some(&pk_indices)); - - for record in reader.records() { - let record = record?; - - if !(record.len() == column_count - || record.len() == column_count + 1 && record.get(column_count) == Some("")) - { - return Err(DatabaseError::MisMatch("columns", "values")); - } - - tx.send(tuple_builder.build_with_row(record.iter())?) - .map_err(|_| DatabaseError::ChannelClose)?; - } - Ok(()) - } } #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use super::*; use crate::binder::copy::ExtSource; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; - use crate::db::DataBaseBuilder; + use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::db::{CatalogKind, DataBaseBuilder}; use crate::errors::DatabaseError; use crate::storage::Storage; use crate::types::CharLengthUnits; use crate::types::LogicalType; use std::io::Write; - use std::sync::Arc; use tempfile::TempDir; use ulid::Ulid; @@ -172,51 +138,34 @@ mod tests { let mut file = tempfile::NamedTempFile::new().expect("failed to create temp file"); write!(file, "{csv}").expect("failed to write file"); - let columns = vec![ - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "a".to_string(), - relation: ColumnRelation::Table { - column_id: Ulid::new(), - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None)?, - false, - )), - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "b".to_string(), - relation: ColumnRelation::Table { - column_id: Ulid::new(), - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new(LogicalType::Float, None, false, None)?, - false, - )), - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c".to_string(), - relation: ColumnRelation::Table { - column_id: Ulid::new(), - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new( - LogicalType::Varchar(Some(10), CharLengthUnits::Characters), - None, - false, - None, - )?, + let tmp_dir = TempDir::new()?; + let mut db = DataBaseBuilder::path(tmp_dir.path()).build_rocksdb()?; + db.ddl("create table test_copy (a int primary key, b float, c varchar(10))")?; + db.load(CatalogKind::Table("test_copy".to_string().into()))?; + + fn test_column( + name: &str, + ty: LogicalType, + primary_key: Option, + ) -> Result { + let mut column = ColumnCatalog::new( + name.to_string(), false, - )), + ColumnDesc::new(ty, primary_key, false, None)?, + ); + column.set_ref_table("t1".to_string().into(), Ulid::new(), false); + Ok(column) + } + + let mut plan_arena = crate::planner::PlanArena::new(db.state.table_arena()); + let columns = vec![ + plan_arena.alloc_column(test_column("a", LogicalType::Integer, Some(0))?), + plan_arena.alloc_column(test_column("b", LogicalType::Float, None)?), + plan_arena.alloc_column(test_column( + "c", + LogicalType::Varchar(Some(10), CharLengthUnits::Characters), + None, + )?), ]; let op = CopyFromFileOperator { @@ -230,24 +179,19 @@ mod tests { header: false, }, }, - schema_ref: Arc::new(columns), + schema_ref: columns, }; - let tmp_dir = TempDir::new()?; - let db = DataBaseBuilder::path(tmp_dir.path()).build_rocksdb()?; - db.run("create table test_copy (a int primary key, b float, c varchar(10))")? - .done()?; - - let storage = db.storage; - let mut transaction = storage.transaction()?; + let transaction = db.storage.transaction()?; let mut executor = crate::execution::execute_mut( CopyFromFile::from(op), - ( + crate::execution::test_utils::empty_context( db.state.table_cache(), db.state.view_cache(), db.state.meta_cache(), ), - &mut transaction, + plan_arena, + &transaction, ); let result = executor.next().expect("copy from file should yield once")?; diff --git a/src/execution/dml/copy_to_file.rs b/src/execution/dml/copy_to_file.rs index 79dc6879..19b18fc1 100644 --- a/src/execution/dml/copy_to_file.rs +++ b/src/execution/dml/copy_to_file.rs @@ -14,15 +14,19 @@ use crate::binder::copy::FileFormat; use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ReadExecutor}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::planner::operator::copy_to_file::CopyToFileOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::tuple_builder::TupleBuilder; +use itertools::Itertools; pub struct CopyToFile { op: CopyToFileOperator, input_plan: LogicalPlan, + column_names: Vec, input: Option, } @@ -31,32 +35,68 @@ impl From<(CopyToFileOperator, LogicalPlan)> for CopyToFile { CopyToFile { op, input_plan: input, + column_names: Default::default(), input: None, } } } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for CopyToFile { + type Input = Self; + fn into_executor( - mut self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - self.input = Some(build_read( + let mut executor = input; + executor.column_names = executor + .input_plan + .take_schema(plan_arena) + .into_iter() + .map(|column| plan_arena.column(column).name().to_string()) + .collect_vec(); + executor.input = Some(build_read( arena, - self.input_plan.take(), + plan_arena, + executor.input_plan.take(), cache, transaction, )); - arena.push(ExecNode::CopyToFile(self)) + arena.push(ExecNode::CopyToFile(executor)) } } impl CopyToFile { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( + fn create_writer(&self) -> Result, DatabaseError> { + let mut writer = match self.op.target.format { + FileFormat::Csv { + delimiter, + quote, + header, + .. + } => csv::WriterBuilder::new() + .delimiter(delimiter as u8) + .quote(quote as u8) + .has_headers(header) + .from_path(self.op.target.path.clone())?, + }; + + if let FileFormat::Csv { header: true, .. } = self.op.target.format { + writer.write_record(&self.column_names)?; + } + + Ok(writer) + } +} + +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for CopyToFile { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(input) = self.input.take() else { arena.finish(); @@ -64,7 +104,7 @@ impl CopyToFile { }; let mut writer = self.create_writer()?; - while arena.next_tuple(input)? { + while arena.next_tuple(input, plan_arena)? { let tuple = arena.result_tuple(); writer.write_record( tuple @@ -76,103 +116,29 @@ impl CopyToFile { } writer.flush().map_err(DatabaseError::from)?; - TupleBuilder::build_result_into(arena.result_tuple_mut(), format!("{}", self.op)); + let message = if self.column_names.is_empty() { + format!("{}", self.op) + } else { + format!("{} [{}]", self.op, self.column_names.iter().format(", ")) + }; + TupleBuilder::build_result_into(arena.result_tuple_mut(), message); arena.resume(); Ok(()) } - - fn create_writer(&self) -> Result, DatabaseError> { - let mut writer = match self.op.target.format { - FileFormat::Csv { - delimiter, - quote, - header, - .. - } => csv::WriterBuilder::new() - .delimiter(delimiter as u8) - .quote(quote as u8) - .has_headers(header) - .from_path(self.op.target.path.clone())?, - }; - - if let FileFormat::Csv { header: true, .. } = self.op.target.format { - let headers = self - .op - .schema_ref - .iter() - .map(|c| c.name()) - .collect::>(); - writer.write_record(headers)?; - } - - Ok(writer) - } } #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use super::*; use crate::binder::copy::ExtSource; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; - use crate::db::DataBaseBuilder; + use crate::db::{CatalogKind, DataBaseBuilder}; use crate::errors::DatabaseError; use crate::planner::operator::table_scan::TableScanOperator; use crate::storage::Storage; - use crate::types::CharLengthUnits; - use crate::types::LogicalType; - use std::sync::Arc; use tempfile::TempDir; - use ulid::Ulid; #[test] fn read_csv() -> Result<(), DatabaseError> { - let columns = vec![ - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "a".to_string(), - relation: ColumnRelation::Table { - column_id: Ulid::new(), - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None)?, - false, - )), - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "b".to_string(), - relation: ColumnRelation::Table { - column_id: Ulid::new(), - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new(LogicalType::Float, None, false, None)?, - false, - )), - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c".to_string(), - relation: ColumnRelation::Table { - column_id: Ulid::new(), - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new( - LogicalType::Varchar(Some(10), CharLengthUnits::Characters), - None, - false, - None, - )?, - false, - )), - ]; - let tmp_dir = TempDir::new()?; let file_path = tmp_dir.path().join("test.csv"); @@ -186,36 +152,42 @@ mod tests { header: true, }, }, - schema_ref: Arc::new(columns), }; let temp_dir = TempDir::new().unwrap(); - let db = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - db.run("create table t1 (a int primary key, b float, c varchar(10))")? - .done()?; + let mut db = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + db.ddl("create table t1 (a int primary key, b float, c varchar(10))")?; + db.load(CatalogKind::Table("t1".to_string().into()))?; db.run("insert into t1 values (1, 1.1, 'foo')")?.done()?; db.run("insert into t1 values (2, 2.0, 'fooo')")?.done()?; db.run("insert into t1 values (3, 2.1, 'Kite')")?.done()?; - let storage = db.storage; - let mut transaction = storage.transaction()?; + let plan_arena = crate::planner::PlanArena::new(db.state.table_arena()); + let transaction = db.storage.transaction()?; let table = transaction .table(db.state.table_cache(), "t1".to_string().into())? .unwrap(); let executor = CopyToFile { op: op.clone(), - input_plan: TableScanOperator::build("t1".to_string().into(), table, true)?, + input_plan: TableScanOperator::build( + "t1".to_string().into(), + table, + true, + &plan_arena, + )?, + column_names: Default::default(), input: None, }; let mut executor = crate::execution::execute( executor, - ( + crate::execution::test_utils::empty_context( db.state.table_cache(), db.state.view_cache(), db.state.meta_cache(), ), - &mut transaction, + plan_arena, + &transaction, ); let tuple = executor.next().expect("executor should yield once")?; @@ -234,7 +206,7 @@ mod tests { let record3 = records.next().unwrap()?; assert_eq!(record3, vec!["3", "2.1", "Kite"]); - assert_eq!(tuple.values[0].to_string(), format!("{op}")); + assert_eq!(tuple.values[0].to_string(), format!("{op} [a, b, c]")); Ok(()) } } diff --git a/src/execution/dml/delete.rs b/src/execution/dml/delete.rs index b925730b..5fd87880 100644 --- a/src/execution/dml/delete.rs +++ b/src/execution/dml/delete.rs @@ -14,16 +14,15 @@ use crate::catalog::TableName; use crate::errors::DatabaseError; -use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; -use crate::expression::ScalarExpression; +use crate::execution::{ + build_read, with_projection_tmp_value, ExecArena, ExecId, ExecNode, ExecutionContext, + ExecutorNode, WriteExecutor, +}; use crate::planner::operator::delete::DeleteOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; -use crate::types::index::{Index, IndexId, IndexType}; +use crate::types::index::Index; use crate::types::tuple_builder::TupleBuilder; -use crate::types::value::DataValue; -use std::collections::HashMap; pub struct Delete { table_name: TableName, @@ -42,90 +41,76 @@ impl From<(DeleteOperator, LogicalPlan)> for Delete { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Delete { + type Input = Self; + fn into_executor( - mut self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - self.input = Some(build_read( + let mut executor = input; + executor.input = Some(build_read( arena, - self.input_plan.take(), + plan_arena, + executor.input_plan.take(), cache, transaction, )); - arena.push(ExecNode::Delete(self)) + arena.push(ExecNode::Delete(executor)) } } -impl Delete { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Delete { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(input) = self.input.take() else { arena.finish(); return Ok(()); }; - let table = arena - .transaction_mut() - .table(arena.table_cache(), self.table_name.clone())? - .ok_or(DatabaseError::TableNotFound)?; - let mut indexes: HashMap = HashMap::new(); - + let index_templates = { + let table = arena + .transaction() + .table(arena.table_cache(), self.table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + table + .indexes() + .map(|index_meta| { + let index_meta = plan_arena.index(*index_meta); + Ok(( + index_meta.id, + index_meta.ty, + index_meta.column_exprs(table, plan_arena)?, + )) + }) + .collect::, DatabaseError>>()? + }; let mut deleted_count = 0; - while arena.next_tuple(input)? { + while arena.next_tuple(input, plan_arena)? { let tuple = arena.result_tuple().clone(); - for index_meta in table.indexes() { - if let Some(Value { exprs, values, .. }) = indexes.get_mut(&index_meta.id) { - let Some(data_value) = - DataValue::values_to_tuple(Projection::projection(&tuple, exprs)?) - else { - continue; - }; - values.push(data_value); - } else { - let mut values = Vec::with_capacity(table.indexes().len()); - let exprs = index_meta.column_exprs(table)?; - let Some(data_value) = - DataValue::values_to_tuple(Projection::projection(&tuple, &exprs)?) - else { - continue; - }; - values.push(data_value); - - indexes.insert( - index_meta.id, - Value { - exprs, - values, - index_ty: index_meta.ty, - }, - ); - } - } if let Some(tuple_id) = &tuple.pk { - for ( - index_id, - Value { - values, index_ty, .. - }, - ) in indexes.iter_mut() - { - for value in values { - arena.transaction_mut().del_index( + for (index_id, index_ty, exprs) in index_templates.iter() { + with_projection_tmp_value(arena, Some(&tuple), exprs, |arena, value| { + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + transaction.del_index( + table_codec, &self.table_name, &Index::new(*index_id, value, *index_ty), tuple_id, - )?; - } + ) + })?; } - arena - .transaction_mut() - .remove_tuple(&self.table_name, tuple_id)?; + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + transaction.remove_tuple(table_codec, &self.table_name, tuple_id)?; deleted_count += 1; } } @@ -135,9 +120,3 @@ impl Delete { Ok(()) } } - -struct Value { - exprs: Vec, - values: Vec, - index_ty: IndexType, -} diff --git a/src/execution/dml/insert.rs b/src/execution/dml/insert.rs index 9b89f6b7..015d4823 100644 --- a/src/execution/dml/insert.rs +++ b/src/execution/dml/insert.rs @@ -12,16 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::{ColumnCatalog, TableName}; +use crate::catalog::TableName; use crate::errors::DatabaseError; -use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + build_read, with_projection_tmp_value, ExecArena, ExecId, ExecNode, ExecutionContext, + ExecutorNode, WriteExecutor, +}; use crate::planner::operator::insert::InsertOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::index::Index; -use crate::types::tuple::SchemaRef; -use crate::types::tuple::Tuple; +use crate::types::tuple::{Schema, Tuple}; use crate::types::tuple_builder::TupleBuilder; use crate::types::value::DataValue; use crate::types::ColumnId; @@ -30,7 +31,7 @@ use std::collections::HashMap; pub struct Insert { table_name: TableName, - input_schema: SchemaRef, + input_schema: Schema, input_plan: LogicalPlan, input: Option, is_overwrite: bool, @@ -45,12 +46,12 @@ impl From<(InsertOperator, LogicalPlan)> for Insert { is_overwrite, is_mapping_by_name, }, - mut input, + input, ): (InsertOperator, LogicalPlan), ) -> Self { Insert { table_name, - input_schema: input.output_schema().clone(), + input_schema: Default::default(), input_plan: input, input: None, is_overwrite, @@ -59,106 +60,124 @@ impl From<(InsertOperator, LogicalPlan)> for Insert { } } -#[derive(Debug, Eq, PartialEq, Hash)] -enum MappingKey<'a> { - Name(&'a str), - Id(Option), -} - -impl ColumnCatalog { - fn key(&self, is_mapping_by_name: bool) -> MappingKey<'_> { - if is_mapping_by_name { - MappingKey::Name(self.name()) - } else { - MappingKey::Id(self.id()) - } - } -} - impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert { + type Input = Self; + fn into_executor( - mut self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - self.input = Some(build_read( + let mut executor = input; + executor.input_schema = executor.input_plan.take_schema(plan_arena); + executor.input = Some(build_read( arena, - self.input_plan.take(), + plan_arena, + executor.input_plan.take(), cache, transaction, )); - arena.push(ExecNode::Insert(self)) + arena.push(ExecNode::Insert(executor)) } } +#[derive(Debug, Eq, PartialEq, Hash)] +enum MappingKey<'a> { + Name(&'a str), + Id(Option), +} + impl Insert { - pub(crate) fn next_tuple<'a, T: Transaction>( + fn column_key<'a>( + column: &'a crate::catalog::ColumnCatalog, + is_mapping_by_name: bool, + ) -> MappingKey<'a> { + if is_mapping_by_name { + MappingKey::Name(column.name()) + } else { + MappingKey::Id(column.id()) + } + } +} + +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Insert { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(input) = self.input.take() else { arena.finish(); return Ok(()); }; - if let Some(table_snapshot) = arena - .transaction() - .table(arena.table_cache(), self.table_name.clone())? - .map(|table| table.dml_snapshot()) - .transpose()? - { + let table_cache = arena.context().table_cache(); + let transaction = arena.transaction(); + let table_snapshot = { + transaction + .table(table_cache, self.table_name.clone())? + .map(|table| table.dml_snapshot(plan_arena)) + .transpose()? + }; + if let Some(table_snapshot) = table_snapshot { if table_snapshot.primary_key_indices.is_empty() { return Err(DatabaseError::not_null()); } let serializers = table_snapshot - .schema_ref + .columns .iter() - .map(|column| column.datatype().serializable()) + .map(|column| plan_arena.column(*column).datatype().serializable()) .collect_vec(); let mut inserted_count = 0; - while arena.next_tuple(input)? { + while arena.next_tuple(input, plan_arena)? { let values = arena.result_tuple().values.clone(); let mut tuple_map = HashMap::new(); for (i, value) in values.into_iter().enumerate() { - tuple_map.insert(self.input_schema[i].key(self.is_mapping_by_name), value); + let column = plan_arena.column(self.input_schema[i]); + tuple_map.insert(Self::column_key(column, self.is_mapping_by_name), value); } let mut values = Vec::with_capacity(table_snapshot.columns_len); - for col in table_snapshot.schema_ref.iter() { + for column in table_snapshot.columns.iter() { + let column = plan_arena.column(*column); let mut value = { - let mut value = tuple_map.remove(&col.key(self.is_mapping_by_name)); + let mut value = + tuple_map.remove(&Self::column_key(column, self.is_mapping_by_name)); if value.is_none() { - value = col.default_value()?; + value = column.default_value()?; } value.unwrap_or(DataValue::Null) }; - value = value.cast(col.datatype())?; - value.check_len(col.datatype())?; - if value.is_null() && !col.nullable() { - return Err(DatabaseError::not_null_column(col.name().to_string())); + value = value.cast(column.datatype())?; + value.check_len(column.datatype())?; + if value.is_null() && !column.nullable() { + return Err(DatabaseError::not_null_column(column.name().to_string())); } values.push(value) } - let pk = Tuple::primary_projection(&table_snapshot.primary_key_indices, &values); + let pk = Tuple::primary_projection(table_snapshot.primary_key_indices, &values); let tuple = Tuple::new(Some(pk), values); for (index_meta, exprs) in table_snapshot.index_metas.iter() { - let values = Projection::projection(&tuple, exprs)?; - let Some(value) = DataValue::values_to_tuple(values) else { - continue; - }; + let index_meta = plan_arena.index(*index_meta); let tuple_id = tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound)?; - let index = Index::new(index_meta.id, &value, index_meta.ty); - arena - .transaction_mut() - .add_index(&self.table_name, index, tuple_id)?; + with_projection_tmp_value(arena, Some(&tuple), exprs, |arena, value| { + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + let index = Index::new(index_meta.id, value, index_meta.ty); + transaction.add_index(table_codec, &self.table_name, index, tuple_id) + })?; } - arena.transaction_mut().append_tuple( + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + transaction.append_tuple( + table_codec, &self.table_name, tuple, &serializers, diff --git a/src/execution/dml/mod.rs b/src/execution/dml/mod.rs index 69e64a46..98b07283 100644 --- a/src/execution/dml/mod.rs +++ b/src/execution/dml/mod.rs @@ -13,7 +13,9 @@ // limitations under the License. pub(crate) mod analyze; +#[cfg(feature = "copy")] pub(crate) mod copy_from_file; +#[cfg(feature = "copy")] pub(crate) mod copy_to_file; pub(crate) mod delete; pub(crate) mod insert; diff --git a/src/execution/dml/update.rs b/src/execution/dml/update.rs index 519b91a9..3dfedbf5 100644 --- a/src/execution/dml/update.rs +++ b/src/execution/dml/update.rs @@ -14,24 +14,24 @@ use crate::catalog::{ColumnRef, TableName}; use crate::errors::DatabaseError; -use crate::execution::dql::projection::Projection; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, WriteExecutor}; +use crate::execution::{ + build_read, with_projection_tmp_value, ExecArena, ExecId, ExecNode, ExecutionContext, + ExecutorNode, WriteExecutor, +}; use crate::expression::ScalarExpression; use crate::planner::operator::update::UpdateOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::index::Index; -use crate::types::tuple::SchemaRef; -use crate::types::tuple::Tuple; +use crate::types::tuple::{Schema, Tuple}; use crate::types::tuple_builder::TupleBuilder; -use crate::types::value::DataValue; use itertools::Itertools; use std::collections::HashMap; pub struct Update { table_name: TableName, value_exprs: Vec<(ColumnRef, ScalarExpression)>, - input_schema: SchemaRef, + input_schema: Schema, input_plan: LogicalPlan, input: Option, } @@ -43,13 +43,13 @@ impl From<(UpdateOperator, LogicalPlan)> for Update { table_name, value_exprs, }, - mut input, + input, ): (UpdateOperator, LogicalPlan), ) -> Self { Update { table_name, value_exprs, - input_schema: input.output_schema().clone(), + input_schema: Default::default(), input_plan: input, input: None, } @@ -57,26 +57,33 @@ impl From<(UpdateOperator, LogicalPlan)> for Update { } impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Update { + type Input = Self; + fn into_executor( - mut self, + input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - self.input = Some(build_read( + let mut executor = input; + executor.input_schema = executor.input_plan.take_schema(plan_arena); + executor.input = Some(build_read( arena, - self.input_plan.take(), + plan_arena, + executor.input_plan.take(), cache, transaction, )); - arena.push(ExecNode::Update(self)) + arena.push(ExecNode::Update(executor)) } } -impl Update { - pub(crate) fn next_tuple<'a, T: Transaction>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Update { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(input) = self.input.take() else { arena.finish(); @@ -85,24 +92,27 @@ impl Update { let mut exprs_map = HashMap::with_capacity(self.value_exprs.len()); for (column, expr) in self.value_exprs.drain(..) { - exprs_map.insert(column.id(), expr); + exprs_map.insert(plan_arena.column(column).id(), expr); } - if let Some(table_snapshot) = arena - .transaction() - .table(arena.table_cache(), self.table_name.clone())? - .map(|table| table.dml_snapshot()) - .transpose()? - { + let table_cache = arena.context().table_cache(); + let transaction = arena.transaction(); + let table_snapshot = { + transaction + .table(table_cache, self.table_name.clone())? + .map(|table| table.dml_snapshot(plan_arena)) + .transpose()? + }; + if let Some(table_snapshot) = table_snapshot { let serializers = self .input_schema .iter() - .map(|column| column.datatype().serializable()) + .map(|column| plan_arena.column(*column).datatype().serializable()) .collect_vec(); let mut updated_count = 0; - while arena.next_tuple(input)? { + while arena.next_tuple(input, plan_arena)? { let mut tuple = arena.result_tuple().clone(); let mut is_overwrite = true; @@ -110,46 +120,47 @@ impl Update { continue; }; for (index_meta, exprs) in table_snapshot.index_metas.iter() { - let values = Projection::projection(&tuple, exprs)?; - let Some(value) = DataValue::values_to_tuple(values) else { - continue; - }; - let index = Index::new(index_meta.id, &value, index_meta.ty); - arena - .transaction_mut() - .del_index(&self.table_name, &index, &old_pk)?; + let index_meta = plan_arena.index(*index_meta); + with_projection_tmp_value(arena, Some(&tuple), exprs, |arena, value| { + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + let index = Index::new(index_meta.id, value, index_meta.ty); + transaction.del_index(table_codec, &self.table_name, &index, &old_pk) + })?; } for (i, column) in self.input_schema.iter().enumerate() { - if let Some(expr) = exprs_map.get(&column.id()) { + if let Some(expr) = exprs_map.get(&plan_arena.column(*column).id()) { let value = expr.eval(Some(&tuple))?; tuple.values[i] = value; } } tuple.pk = Some(Tuple::primary_projection( - &table_snapshot.primary_key_indices, + table_snapshot.primary_key_indices, &tuple.values, )); let new_pk = tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound)?; if new_pk != &old_pk { - arena - .transaction_mut() - .remove_tuple(&self.table_name, &old_pk)?; + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + transaction.remove_tuple(table_codec, &self.table_name, &old_pk)?; is_overwrite = false; } for (index_meta, exprs) in table_snapshot.index_metas.iter() { - let values = Projection::projection(&tuple, exprs)?; - let Some(value) = DataValue::values_to_tuple(values) else { - continue; - }; - let index = Index::new(index_meta.id, &value, index_meta.ty); - arena - .transaction_mut() - .add_index(&self.table_name, index, new_pk)?; + let index_meta = plan_arena.index(*index_meta); + with_projection_tmp_value(arena, Some(&tuple), exprs, |arena, value| { + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + let index = Index::new(index_meta.id, value, index_meta.ty); + transaction.add_index(table_codec, &self.table_name, index, new_pk) + })?; } - arena.transaction_mut().append_tuple( + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec_mut(); + transaction.append_tuple( + table_codec, &self.table_name, tuple, &serializers, diff --git a/src/execution/dql/aggregate/avg.rs b/src/execution/dql/aggregate/avg.rs index 400a2fa3..a169fd20 100644 --- a/src/execution/dql/aggregate/avg.rs +++ b/src/execution/dql/aggregate/avg.rs @@ -50,20 +50,21 @@ impl Accumulator for AvgAccumulator { Ok(()) } - fn evaluate(&self) -> Result { - let Some(acc) = &self.inner else { + fn evaluate(self: Box) -> Result { + let Self { inner, count } = *self; + let Some(acc) = inner else { return Ok(DataValue::Null); }; - let mut value = acc.evaluate()?; + let mut value = acc.into_result(); let value_ty = value.logical_type(); - if self.count == 0 { + if count == 0 { return Ok(DataValue::Null); } let quantity = if value_ty.is_signed_numeric() { - DataValue::Int64(self.count as i64) + DataValue::Int64(count as i64) } else { - DataValue::UInt32(self.count as u32) + DataValue::UInt32(count as u32) }; let quantity_ty = quantity.logical_type(); diff --git a/src/execution/dql/aggregate/count.rs b/src/execution/dql/aggregate/count.rs index 8a18f2a2..cdc14a8c 100644 --- a/src/execution/dql/aggregate/count.rs +++ b/src/execution/dql/aggregate/count.rs @@ -37,7 +37,7 @@ impl Accumulator for CountAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(self: Box) -> Result { Ok(DataValue::Int32(self.result)) } } @@ -63,7 +63,7 @@ impl Accumulator for DistinctCountAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(self: Box) -> Result { Ok(DataValue::Int32(self.distinct_values.len() as i32)) } } diff --git a/src/execution/dql/aggregate/hash_agg.rs b/src/execution/dql/aggregate/hash_agg.rs index 9df8c310..dc4ce73c 100644 --- a/src/execution/dql/aggregate/hash_agg.rs +++ b/src/execution/dql/aggregate/hash_agg.rs @@ -14,7 +14,9 @@ use crate::errors::DatabaseError; use crate::execution::dql::aggregate::{create_accumulators, Accumulator}; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::LogicalPlan; @@ -33,7 +35,7 @@ pub struct HashAggExecutor { output: Option, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for HashAggExecutor { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashAggExecutor { type Input = (AggregateOperator, LogicalPlan); fn into_executor( @@ -46,10 +48,11 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for HashAggExecutor { input, ): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let input = build_read(arena, input, cache, transaction); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::HashAgg(HashAggExecutor { agg_calls, groupby_exprs, @@ -57,13 +60,19 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for HashAggExecutor { output: None, })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for HashAggExecutor { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.output.is_none() { let mut group_hash_accs: HashMap, Vec>> = HashMap::new(); - while arena.next_tuple(self.input)? { + while arena.next_tuple(self.input, plan_arena)? { let tuple = arena.result_tuple(); let group_keys = self .groupby_exprs @@ -105,7 +114,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for HashAggExecutor { output.values.clear(); output.values.reserve(accs.len() + group_keys.len()); - for acc in accs.iter() { + for acc in accs { output.values.push(acc.evaluate()?); } output.values.extend(group_keys); @@ -116,7 +125,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for HashAggExecutor { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::errors::DatabaseError; use crate::execution::dql::aggregate::hash_agg::HashAggExecutor; use crate::execution::dql::test::build_integers; @@ -130,35 +139,34 @@ mod test { use crate::planner::operator::values::ValuesOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; - use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; + use crate::storage::rocksdb::RocksStorage; use crate::storage::Storage; use crate::types::value::DataValue; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; use itertools::Itertools; - use std::hash::RandomState; - use std::sync::Arc; use tempfile::TempDir; #[test] fn test_hash_agg() -> Result<(), DatabaseError> { - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path()).unwrap(); - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let desc = ColumnDesc::new(LogicalType::Integer, None, false, None)?; + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); - let t1_schema = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new("c1".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c2".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c3".to_string(), true, desc.clone())), - ]); + let t1_schema = vec![ + plan_arena.alloc_column(ColumnCatalog::new("c1".to_string(), true, desc.clone())), + plan_arena.alloc_column(ColumnCatalog::new("c2".to_string(), true, desc.clone())), + plan_arena.alloc_column(ColumnCatalog::new("c3".to_string(), true, desc.clone())), + ]; - let input = LogicalPlan { - operator: Operator::Values(ValuesOperator { + let input = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![ vec![ DataValue::Int32(0), @@ -183,17 +191,15 @@ mod test { ], schema_ref: t1_schema.clone(), }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; + Childrens::None, + ); let plan = LogicalPlan::new( Operator::Aggregate(AggregateOperator { - groupby_exprs: vec![ScalarExpression::column_expr(t1_schema[0].clone(), 0)], + groupby_exprs: vec![ScalarExpression::column_expr(t1_schema[0], 0)], agg_calls: vec![ScalarExpression::AggCall { distinct: false, kind: AggKind::Sum, - args: vec![ScalarExpression::column_expr(t1_schema[1].clone(), 1)], + args: vec![ScalarExpression::column_expr(t1_schema[1], 1)], ty: LogicalType::Integer, }], is_distinct: false, @@ -210,15 +216,16 @@ mod test { .build(); let plan = pipeline .instantiate(plan) - .find_best::(None)?; + .find_best(None, &mut plan_arena)?; let Operator::Aggregate(op) = plan.operator else { unreachable!() }; let tuples = try_collect(execute_input::<_, HashAggExecutor>( (op, plan.childrens.pop_only()), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; assert_eq!(tuples.len(), 2); diff --git a/src/execution/dql/aggregate/min_max.rs b/src/execution/dql/aggregate/min_max.rs index 49f9a38a..812f5b63 100644 --- a/src/execution/dql/aggregate/min_max.rs +++ b/src/execution/dql/aggregate/min_max.rs @@ -55,7 +55,7 @@ impl Accumulator for MinMaxAccumulator { Ok(()) } - fn evaluate(&self) -> Result { - Ok(self.inner.clone().unwrap_or(DataValue::Null)) + fn evaluate(self: Box) -> Result { + Ok(self.inner.unwrap_or(DataValue::Null)) } } diff --git a/src/execution/dql/aggregate/mod.rs b/src/execution/dql/aggregate/mod.rs index 44a9bcb1..3153d361 100644 --- a/src/execution/dql/aggregate/mod.rs +++ b/src/execution/dql/aggregate/mod.rs @@ -39,7 +39,7 @@ pub trait Accumulator { fn update_value(&mut self, value: &DataValue) -> Result<(), DatabaseError>; /// returns its value based on its current state. - fn evaluate(&self) -> Result; + fn evaluate(self: Box) -> Result; } fn create_accumulator(expr: &ScalarExpression) -> Result, DatabaseError> { diff --git a/src/execution/dql/aggregate/simple_agg.rs b/src/execution/dql/aggregate/simple_agg.rs index 39ac775c..6986e826 100644 --- a/src/execution/dql/aggregate/simple_agg.rs +++ b/src/execution/dql/aggregate/simple_agg.rs @@ -14,7 +14,9 @@ use crate::errors::DatabaseError; use crate::execution::dql::aggregate::create_accumulators; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::LogicalPlan; @@ -25,24 +27,31 @@ pub struct SimpleAggExecutor { returned: bool, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for SimpleAggExecutor { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for SimpleAggExecutor { type Input = (AggregateOperator, LogicalPlan); fn into_executor( (AggregateOperator { agg_calls, .. }, input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let input = build_read(arena, input, cache, transaction); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::SimpleAgg(SimpleAggExecutor { agg_calls, input, returned: false, })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for SimpleAggExecutor { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.returned { arena.finish(); return Ok(()); @@ -50,7 +59,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for SimpleAggExecutor { let mut accs = create_accumulators(&self.agg_calls)?; - while arena.next_tuple(self.input)? { + while arena.next_tuple(self.input, plan_arena)? { let tuple = arena.result_tuple(); for (acc, expr) in accs.iter_mut().zip(self.agg_calls.iter()) { let ScalarExpression::AggCall { args, .. } = expr else { diff --git a/src/execution/dql/aggregate/stream_distinct.rs b/src/execution/dql/aggregate/stream_distinct.rs index bd83a635..6aa57549 100644 --- a/src/execution/dql/aggregate/stream_distinct.rs +++ b/src/execution/dql/aggregate/stream_distinct.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::expression::ScalarExpression; use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::LogicalPlan; @@ -29,16 +31,17 @@ pub struct StreamDistinctExecutor { scratch: Tuple, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for StreamDistinctExecutor { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for StreamDistinctExecutor { type Input = (AggregateOperator, LogicalPlan); fn into_executor( (op, input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let input = build_read(arena, input, cache, transaction); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::StreamDistinct(StreamDistinctExecutor { groupby_exprs: op.groupby_exprs, input, @@ -46,10 +49,16 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for StreamDistinctExecutor { scratch: Tuple::default(), })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for StreamDistinctExecutor { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { loop { - if !arena.next_tuple(self.input)? { + if !arena.next_tuple(self.input, plan_arena)? { arena.finish(); return Ok(()); } @@ -75,7 +84,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for StreamDistinctExecutor { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::errors::DatabaseError; use crate::execution::dql::aggregate::stream_distinct::StreamDistinctExecutor; use crate::execution::{execute_input, try_collect}; @@ -87,30 +96,27 @@ mod tests { use crate::planner::operator::values::ValuesOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; - use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; + use crate::storage::rocksdb::RocksStorage; use crate::storage::{StatisticsMetaCache, Storage, TableCache, ViewCache}; use crate::types::value::DataValue; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; use itertools::Itertools; - use std::hash::RandomState; - use std::sync::Arc; use tempfile::TempDir; #[allow(clippy::type_complexity)] fn build_test_storage() -> Result< ( - Arc, - Arc, - Arc, + TableCache, + ViewCache, + StatisticsMetaCache, TempDir, RocksStorage, ), DatabaseError, > { - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; @@ -118,7 +124,10 @@ mod tests { Ok((table_cache, view_cache, meta_cache, temp_dir, storage)) } - fn optimize_exprs(plan: LogicalPlan) -> Result { + fn optimize_exprs( + plan: LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { HepOptimizerPipeline::builder() .before_batch( "Expression Remapper".to_string(), @@ -127,17 +136,16 @@ mod tests { ) .build() .instantiate(plan) - .find_best::(None) + .find_best(None, arena) } #[test] fn stream_distinct_single_column_sorted() -> Result<(), DatabaseError> { let desc = ColumnDesc::new(LogicalType::Integer, None, false, None)?; - let schema_ref = Arc::new(vec![ColumnRef::from(ColumnCatalog::new( - "c1".to_string(), - true, - desc, - ))]); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let schema_ref = + vec![plan_arena.alloc_column(ColumnCatalog::new("c1".to_string(), true, desc))]; let input = LogicalPlan::new( Operator::Values(ValuesOperator { @@ -153,22 +161,23 @@ mod tests { Childrens::None, ); let agg = AggregateOperator { - groupby_exprs: vec![ScalarExpression::column_expr(schema_ref[0].clone(), 0)], + groupby_exprs: vec![ScalarExpression::column_expr(schema_ref[0], 0)], agg_calls: vec![], is_distinct: true, }; let plan = LogicalPlan::new(Operator::Aggregate(agg), Childrens::Only(Box::new(input))); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Aggregate(agg) = plan.operator else { unreachable!() }; let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, StreamDistinctExecutor>( (agg, plan.childrens.pop_only()), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; let actual = tuples @@ -184,10 +193,12 @@ mod tests { #[test] fn stream_distinct_multi_column_sorted() -> Result<(), DatabaseError> { let desc = ColumnDesc::new(LogicalType::Integer, None, false, None)?; - let schema_ref = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new("c1".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c2".to_string(), true, desc)), - ]); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let schema_ref = vec![ + plan_arena.alloc_column(ColumnCatalog::new("c1".to_string(), true, desc.clone())), + plan_arena.alloc_column(ColumnCatalog::new("c2".to_string(), true, desc)), + ]; let input = LogicalPlan::new( Operator::Values(ValuesOperator { @@ -204,24 +215,25 @@ mod tests { ); let agg = AggregateOperator { groupby_exprs: vec![ - ScalarExpression::column_expr(schema_ref[0].clone(), 0), - ScalarExpression::column_expr(schema_ref[1].clone(), 1), + ScalarExpression::column_expr(schema_ref[0], 0), + ScalarExpression::column_expr(schema_ref[1], 1), ], agg_calls: vec![], is_distinct: true, }; let plan = LogicalPlan::new(Operator::Aggregate(agg), Childrens::Only(Box::new(input))); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Aggregate(agg) = plan.operator else { unreachable!() }; let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, StreamDistinctExecutor>( (agg, plan.childrens.pop_only()), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; let actual = tuples.into_iter().map(|tuple| tuple.values).collect_vec(); diff --git a/src/execution/dql/aggregate/sum.rs b/src/execution/dql/aggregate/sum.rs index 7bb6375c..a86f6dd0 100644 --- a/src/execution/dql/aggregate/sum.rs +++ b/src/execution/dql/aggregate/sum.rs @@ -15,7 +15,7 @@ use crate::errors::DatabaseError; use crate::execution::dql::aggregate::Accumulator; use crate::expression::BinaryOperator; -use crate::types::evaluator::{binary_create, BinaryEvaluatorBox}; +use crate::types::evaluator::{binary_create, BinaryEvaluatorRef}; use crate::types::value::DataValue; use crate::types::LogicalType; use ahash::RandomState; @@ -24,7 +24,7 @@ use std::collections::HashSet; pub struct SumAccumulator { result: DataValue, - evaluator: BinaryEvaluatorBox, + evaluator: BinaryEvaluatorRef, } impl SumAccumulator { @@ -36,6 +36,10 @@ impl SumAccumulator { evaluator: binary_create(ty, BinaryOperator::Plus)?, }) } + + pub(super) fn into_result(self) -> DataValue { + self.result + } } impl Accumulator for SumAccumulator { @@ -51,8 +55,8 @@ impl Accumulator for SumAccumulator { Ok(()) } - fn evaluate(&self) -> Result { - Ok(self.result.clone()) + fn evaluate(self: Box) -> Result { + Ok(self.into_result()) } } @@ -80,7 +84,7 @@ impl Accumulator for DistinctSumAccumulator { Ok(()) } - fn evaluate(&self) -> Result { - self.inner.evaluate() + fn evaluate(self: Box) -> Result { + Ok(self.inner.into_result()) } } diff --git a/src/execution/dql/describe.rs b/src/execution/dql/describe.rs index c1732a59..ba0e227b 100644 --- a/src/execution/dql/describe.rs +++ b/src/execution/dql/describe.rs @@ -14,7 +14,7 @@ use crate::catalog::{ColumnCatalog, ColumnRef, TableName}; use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::planner::operator::describe::DescribeOperator; use crate::storage::Transaction; use crate::types::value::{DataValue, Utf8Type}; @@ -56,76 +56,75 @@ impl From for Describe { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Describe { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::Describe(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Describe { - type Input = DescribeOperator; + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::Describe(Describe::from(input))) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - Describe::next_tuple(self, arena) + let executor = input; + arena.push(ExecNode::Describe(executor)) } } -impl Describe { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Describe { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { if self.columns.is_none() { let table = arena - .transaction_mut() + .transaction() .table(arena.table_cache(), self.table_name.clone())? .ok_or(DatabaseError::TableNotFound)?; - self.columns = Some(table.columns().cloned().collect()); + self.columns = Some(table.columns().copied().collect()); } - let Some(column) = self + let Some(column_ref) = self .columns .as_ref() .and_then(|columns| columns.get(self.cursor)) - .cloned() + .copied() else { arena.finish(); return Ok(()); }; self.cursor += 1; + let column = plan_arena.column(column_ref); + let default = describe_default(column, plan_arena); + let mapping = column_ref.to_string(); let output = arena.result_tuple_mut(); output.pk = None; output.values.clear(); - fill_describe_row(&mut output.values, &column); + fill_describe_row(&mut output.values, column, default, mapping); arena.resume(); Ok(()) } } -fn fill_describe_row(values: &mut Vec, column: &ColumnCatalog) { - let datatype = column.datatype(); - let default = column +fn describe_default(column: &ColumnCatalog, arena: &crate::planner::PlanArena) -> String { + column .desc() .default .as_ref() - .map(|expr| format!("{expr}")) - .unwrap_or_else(|| "null".to_string()); + .map(|expr| expr.output_name(arena)) + .unwrap_or_else(|| "null".to_string()) +} + +fn fill_describe_row( + values: &mut Vec, + column: &ColumnCatalog, + default: String, + mapping: String, +) { + let datatype = column.datatype(); values.push(DataValue::Utf8 { value: column.name().to_string(), @@ -156,6 +155,11 @@ fn fill_describe_row(values: &mut Vec, column: &ColumnCatalog) { ty: Utf8Type::Variable(None), unit: CharLengthUnits::Characters, }); + values.push(DataValue::Utf8 { + value: mapping, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }); } fn key_value(column: &ColumnCatalog) -> DataValue { diff --git a/src/execution/dql/dummy.rs b/src/execution/dql/dummy.rs index 22ec66ce..3c1d4445 100644 --- a/src/execution/dql/dummy.rs +++ b/src/execution/dql/dummy.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::storage::Transaction; use crate::types::tuple::Tuple; @@ -30,37 +30,25 @@ impl Default for Dummy { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Dummy { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::Dummy(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Dummy { type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::Dummy(input)) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - Dummy::next_tuple(self, arena) + let executor = input; + arena.push(ExecNode::Dummy(executor)) } } -impl Dummy { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Dummy { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + _: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(row) = self.row.take() else { arena.finish(); diff --git a/src/execution/dql/explain.rs b/src/execution/dql/explain.rs index dcffb297..feebadfe 100644 --- a/src/execution/dql/explain.rs +++ b/src/execution/dql/explain.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::planner::LogicalPlan; use crate::storage::Transaction; use crate::types::value::{DataValue, Utf8Type}; @@ -34,51 +34,37 @@ impl From for Explain { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Explain { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::Explain(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Explain { - type Input = LogicalPlan; + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::Explain(Explain { - plan: input, - emitted: false, - })) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - Explain::next_tuple(self, arena) + let executor = input; + arena.push(ExecNode::Explain(executor)) } } -impl Explain { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Explain { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { if self.emitted { arena.finish(); return Ok(()); } + let plan = self.plan.explain(plan_arena, 0); let output = arena.result_tuple_mut(); output.pk = None; output.values.clear(); output.values.push(DataValue::Utf8 { - value: self.plan.explain(0), + value: plan, ty: Utf8Type::Variable(None), unit: CharLengthUnits::Characters, }); diff --git a/src/execution/dql/filter.rs b/src/execution/dql/filter.rs index c414eff8..6160220a 100644 --- a/src/execution/dql/filter.rs +++ b/src/execution/dql/filter.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::expression::ScalarExpression; use crate::planner::operator::filter::FilterOperator; use crate::planner::LogicalPlan; @@ -23,22 +25,29 @@ pub struct Filter { input: ExecId, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Filter { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Filter { type Input = (FilterOperator, LogicalPlan); fn into_executor( (FilterOperator { predicate, .. }, input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let input = build_read(arena, input, cache, transaction); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::Filter(Filter { predicate, input })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Filter { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { loop { - if !arena.next_tuple(self.input)? { + if !arena.next_tuple(self.input, plan_arena)? { arena.finish(); return Ok(()); }; diff --git a/src/execution/dql/function_scan.rs b/src/execution/dql/function_scan.rs index bcec714f..4a54f961 100644 --- a/src/execution/dql/function_scan.rs +++ b/src/execution/dql/function_scan.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::expression::function::table::TableFunction; use crate::planner::operator::function_scan::FunctionScanOperator; use crate::storage::Transaction; @@ -34,41 +34,29 @@ impl From for FunctionScan { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for FunctionScan { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::FunctionScan(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for FunctionScan { - type Input = FunctionScanOperator; + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::FunctionScan(FunctionScan::from(input))) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - FunctionScan::next_tuple(self, arena) + let executor = input; + arena.push(ExecNode::FunctionScan(executor)) } } -impl FunctionScan { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for FunctionScan { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + _: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { if self.iter.is_none() { - let TableFunction { args, inner } = &self.table_function; - self.iter = Some(inner.eval(args)?); + let TableFunction { args, catalog } = &self.table_function; + self.iter = Some(catalog.inner.eval(args)?); } let tuple = self.iter.as_mut().and_then(Iterator::next).transpose()?; diff --git a/src/execution/dql/index_scan.rs b/src/execution/dql/index_scan.rs index 87adc236..d2714acc 100644 --- a/src/execution/dql/index_scan.rs +++ b/src/execution/dql/index_scan.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::expression::range_detacher::Range; use crate::planner::operator::table_scan::TableScanOperator; use crate::storage::{IndexIter, IndexRanges, Iter, Transaction}; @@ -59,52 +59,26 @@ impl<'a, T: Transaction + 'a> } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for IndexScan<'a, T> { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::IndexScan(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for IndexScan<'a, T> { - type Input = ( - TableScanOperator, - IndexMetaRef, - IndexLookup, - Option>, - Option>, - ); + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::IndexScan(IndexScan::from(input))) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - IndexScan::next_tuple(self, arena) + let executor = input; + arena.push(ExecNode::IndexScan(executor)) } } -impl<'a, T: Transaction + 'a> IndexScan<'a, T> { - fn ranges_from_lookup(lookup: IndexLookup, arena: &mut ExecArena<'a, T>) -> IndexRanges { - match lookup { - IndexLookup::Static(Range::SortedRanges(ranges)) => ranges.into(), - IndexLookup::Static(range) => range.into(), - IndexLookup::Probe => match arena.pop_runtime_probe() { - RuntimeIndexProbe::Eq(value) => Range::Eq(value).into(), - RuntimeIndexProbe::Scope { min, max } => Range::Scope { min, max }.into(), - }, - } - } - - pub(crate) fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for IndexScan<'a, T> { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.iter.is_none() { let Some(TableScanOperator { table_name, @@ -121,12 +95,14 @@ impl<'a, T: Transaction + 'a> IndexScan<'a, T> { self.lookup.take().expect("index scan lookup initialized"), arena, ); - self.iter = Some(arena.transaction().read_by_index( - arena.table_cache(), + let state = arena.local_state(plan_arena); + self.iter = Some(state.transaction().read_by_index( + state.context.table_cache, + state.plan_arena, table_name, limit, columns, - self.index_by.clone(), + self.index_by, ranges, with_pk, self.covered_deserializers.take(), @@ -134,11 +110,12 @@ impl<'a, T: Transaction + 'a> IndexScan<'a, T> { )?); } + let state = arena.local_state(plan_arena); if self .iter .as_mut() .expect("index scan iterator initialized") - .next_tuple_into(arena.result_tuple_mut())? + .next_tuple_into(state.table_codec, &mut state.result.tuple)? { arena.resume(); } else { @@ -147,3 +124,16 @@ impl<'a, T: Transaction + 'a> IndexScan<'a, T> { Ok(()) } } + +impl<'a, T: Transaction + 'a> IndexScan<'a, T> { + fn ranges_from_lookup(lookup: IndexLookup, arena: &mut ExecArena<'a, T>) -> IndexRanges { + match lookup { + IndexLookup::Static(Range::SortedRanges(ranges)) => ranges.into(), + IndexLookup::Static(range) => range.into(), + IndexLookup::Probe => match arena.pop_runtime_probe() { + RuntimeIndexProbe::Eq(value) => Range::Eq(value).into(), + RuntimeIndexProbe::Scope { min, max } => Range::Scope { min, max }.into(), + }, + } + } +} diff --git a/src/execution/dql/join/hash_join.rs b/src/execution/dql/join/hash_join.rs index 065e2d57..88dd4444 100644 --- a/src/execution/dql/join/hash_join.rs +++ b/src/execution/dql/join/hash_join.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::execution::dql::join::hash::full_join::FullJoinState; use crate::execution::dql::join::hash::inner_join::InnerJoinState; @@ -21,10 +20,9 @@ use crate::execution::dql::join::hash::right_join::RightJoinState; use crate::execution::dql::join::hash::{ JoinProbeState, JoinProbeStateImpl, LeftDropState, ProbeState, }; -use crate::execution::dql::join::joins_nullable; use crate::execution::dql::sort::BumpVec; use crate::execution::{ - build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor, + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, }; use crate::expression::ScalarExpression; use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; @@ -70,7 +68,7 @@ enum HashJoinState { impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for HashJoin { fn from( - (JoinOperator { on, join_type }, mut left_input, mut right_input): ( + (JoinOperator { on, join_type }, left_input, right_input): ( JoinOperator, LogicalPlan, LogicalPlan, @@ -93,27 +91,14 @@ impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for HashJoin { None }; - let (left_force_nullable, right_force_nullable) = joins_nullable(&join_type); - - let mut full_schema_ref = Vec::clone(left_input.output_schema()); - let left_schema_len = full_schema_ref.len(); - - force_nullable(&mut full_schema_ref, left_force_nullable); - full_schema_ref.extend_from_slice(right_input.output_schema()); - force_nullable( - &mut full_schema_ref[left_schema_len..], - right_force_nullable, - ); - let right_schema_len = full_schema_ref.len() - left_schema_len; - HashJoin { state: HashJoinState::Build, ty: join_type, on_left_keys, on_right_keys, filter: filter_expr, - left_schema_len, - right_schema_len, + left_schema_len: 0, + right_schema_len: 0, left_input_plan: left_input, right_input_plan: right_input, left_input: 0, @@ -124,14 +109,6 @@ impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for HashJoin { } } -fn force_nullable(schema: &mut [ColumnRef], force_nullable: bool) { - for column in schema.iter_mut() { - if let Some(new_column) = column.nullable_for_join(force_nullable) { - *column = new_column; - } - } -} - impl HashJoin { #[allow(clippy::mutable_key_type)] fn own_bump_vec(buf: BumpVec<'_, DataValue>) -> BumpVec<'static, DataValue> { @@ -165,6 +142,7 @@ impl HashJoin { fn initialize_build<'a, T: Transaction + 'a>( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { if !matches!(self.state, HashJoinState::Build) { return Ok(()); @@ -178,7 +156,7 @@ impl HashJoin { let mut build_buf = BumpVec::with_capacity_in(self.on_left_keys.len(), &self.bump); let mut build_count = 0usize; - while arena.next_tuple(self.left_input)? { + while arena.next_tuple(self.left_input, plan_arena)? { let tuple = arena.result_tuple().clone(); Self::eval_keys(&self.on_left_keys, &tuple, &mut build_buf)?; @@ -246,45 +224,49 @@ pub(crate) struct BuildState { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashJoin { - fn into_executor( - mut self, - arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId { - self.left_input = build_read(arena, self.left_input_plan.take(), cache, transaction); - self.right_input = build_read(arena, self.right_input_plan.take(), cache, transaction); - arena.push(ExecNode::HashJoin(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for HashJoin { - type Input = (JoinOperator, LogicalPlan, LogicalPlan); + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - >::into_executor(Self::from(input), arena, cache, transaction) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - HashJoin::next_tuple(self, arena) + let mut executor = input; + let left_schema_len = executor.left_input_plan.output_schema(plan_arena).len(); + let right_schema_len = executor.right_input_plan.output_schema(plan_arena).len(); + executor.left_schema_len = left_schema_len; + executor.right_schema_len = right_schema_len; + executor.left_input = build_read( + arena, + plan_arena, + executor.left_input_plan.take(), + cache, + transaction, + ); + executor.right_input = build_read( + arena, + plan_arena, + executor.right_input_plan.take(), + cache, + transaction, + ); + arena.push(ExecNode::HashJoin(executor)) } } -impl HashJoin { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for HashJoin { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { if let Some(err) = self.init_error.take() { return Err(err); } - self.initialize_build(arena)?; + self.initialize_build(arena, plan_arena)?; let mut state = std::mem::replace(&mut self.state, HashJoinState::End); loop { @@ -298,7 +280,7 @@ impl HashJoin { } => { let probe_finished = loop { if probe_state.is_none() { - if !arena.next_tuple(self.right_input)? { + if !arena.next_tuple(self.right_input, plan_arena)? { break true; } let tuple = arena.result_tuple().clone(); @@ -381,7 +363,7 @@ impl HashJoin { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::errors::DatabaseError; use crate::execution::dql::join::hash_join::HashJoin; use crate::execution::dql::test::build_integers; @@ -394,16 +376,16 @@ mod test { use crate::planner::operator::values::ValuesOperator; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; - use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; + use crate::storage::rocksdb::RocksStorage; use crate::storage::Storage; use crate::types::value::DataValue; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; - use std::hash::RandomState; - use std::sync::Arc; use tempfile::TempDir; - fn optimize_exprs(plan: LogicalPlan) -> Result { + fn optimize_exprs( + plan: LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { HepOptimizerPipeline::builder() .before_batch( "Expression Remapper".to_string(), @@ -412,10 +394,12 @@ mod test { ) .build() .instantiate(plan) - .find_best::(None) + .find_best(None, arena) } - fn build_join_values() -> ( + fn build_join_values( + arena: &mut crate::planner::PlanArena, + ) -> ( Vec<(ScalarExpression, ScalarExpression)>, LogicalPlan, LogicalPlan, @@ -423,24 +407,24 @@ mod test { let desc = ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(); let t1_columns = vec![ - ColumnRef::from(ColumnCatalog::new("c1".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c2".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c3".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c1".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c2".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c3".to_string(), true, desc.clone())), ]; let t2_columns = vec![ - ColumnRef::from(ColumnCatalog::new("c4".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c5".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c6".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c4".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c5".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c6".to_string(), true, desc.clone())), ]; let on_keys = vec![( - ScalarExpression::column_expr(t1_columns[0].clone(), 0), - ScalarExpression::column_expr(t2_columns[0].clone(), 0), + ScalarExpression::column_expr(t1_columns[0], 0), + ScalarExpression::column_expr(t2_columns[0], 0), )]; - let values_t1 = LogicalPlan { - operator: Operator::Values(ValuesOperator { + let values_t1 = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![ vec![ DataValue::Int32(0), @@ -458,15 +442,13 @@ mod test { DataValue::Int32(7), ], ], - schema_ref: Arc::new(t1_columns), + schema_ref: t1_columns, }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; + Childrens::None, + ); - let values_t2 = LogicalPlan { - operator: Operator::Values(ValuesOperator { + let values_t2 = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![ vec![ DataValue::Int32(0), @@ -489,12 +471,10 @@ mod test { DataValue::Int32(1), ], ], - schema_ref: Arc::new(t2_columns), + schema_ref: t2_columns, }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; + Childrens::None, + ); (on_keys, values_t1, values_t2) } @@ -503,11 +483,13 @@ mod test { fn test_inner_join() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right) = build_join_values(); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right) = build_join_values(&mut plan_arena); let plan = LogicalPlan::new( Operator::Join(JoinOperator { @@ -522,7 +504,7 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() @@ -530,8 +512,9 @@ mod test { let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( HashJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -557,11 +540,13 @@ mod test { fn test_left_join() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right) = build_join_values(); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right) = build_join_values(&mut plan_arena); let plan = LogicalPlan::new( Operator::Join(JoinOperator { @@ -576,7 +561,7 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() @@ -586,8 +571,9 @@ mod test { let executor = HashJoin::from((op.clone(), left.clone(), right.clone())); let tuples = try_collect(crate::execution::execute( executor, - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; assert_eq!(tuples.len(), 4); @@ -617,11 +603,13 @@ mod test { fn test_right_join() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right) = build_join_values(); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right) = build_join_values(&mut plan_arena); let plan = LogicalPlan::new( Operator::Join(JoinOperator { @@ -636,7 +624,7 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() @@ -644,8 +632,9 @@ mod test { let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( HashJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -675,55 +664,50 @@ mod test { fn test_right_join_filter_only_left_columns() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let desc = ColumnDesc::new(LogicalType::Integer, None, false, None)?; let left_columns = vec![ - ColumnRef::from(ColumnCatalog::new("k".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("v".to_string(), true, desc.clone())), + plan_arena.alloc_column(ColumnCatalog::new("k".to_string(), true, desc.clone())), + plan_arena.alloc_column(ColumnCatalog::new("v".to_string(), true, desc.clone())), ]; - let right_columns = vec![ColumnRef::from(ColumnCatalog::new( - "rk".to_string(), - true, - desc.clone(), - ))]; + let right_columns = + vec![plan_arena.alloc_column(ColumnCatalog::new("rk".to_string(), true, desc.clone()))]; let on_keys = vec![( - ScalarExpression::column_expr(left_columns[0].clone(), 0), - ScalarExpression::column_expr(right_columns[0].clone(), 0), + ScalarExpression::column_expr(left_columns[0], 0), + ScalarExpression::column_expr(right_columns[0], 0), )]; let filter_expr = ScalarExpression::Binary { op: BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::column_expr(left_columns[1].clone(), 1)), + left_expr: Box::new(ScalarExpression::column_expr(left_columns[1], 1)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Boolean, }; - let left = LogicalPlan { - operator: Operator::Values(ValuesOperator { + let left = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![ vec![DataValue::Int32(2), DataValue::Int32(0)], vec![DataValue::Int32(2), DataValue::Int32(5)], ], - schema_ref: Arc::new(left_columns), + schema_ref: left_columns, }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; - let right = LogicalPlan { - operator: Operator::Values(ValuesOperator { + Childrens::None, + ); + let right = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![vec![DataValue::Int32(2)]], - schema_ref: Arc::new(right_columns), + schema_ref: right_columns, }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; + Childrens::None, + ); let plan = LogicalPlan::new( Operator::Join(JoinOperator { @@ -739,7 +723,7 @@ mod test { }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() @@ -747,8 +731,9 @@ mod test { let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( HashJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -769,11 +754,13 @@ mod test { fn test_full_join() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right) = build_join_values(); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right) = build_join_values(&mut plan_arena); let plan = LogicalPlan::new( Operator::Join(JoinOperator { @@ -788,7 +775,7 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() @@ -796,8 +783,9 @@ mod test { let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( HashJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index 788ba9b3..f4e348df 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -17,13 +17,13 @@ use crate::errors::DatabaseError; use crate::execution::{ - build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor, + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, }; use crate::expression::ScalarExpression; use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; use crate::planner::LogicalPlan; use crate::storage::Transaction; -use crate::types::tuple::{Schema, SplitTupleRef, Tuple}; +use crate::types::tuple::{SplitTupleRef, Tuple}; use crate::types::value::DataValue; use fixedbitset::FixedBitSet; use itertools::Itertools; @@ -37,26 +37,6 @@ struct EqualCondition { } impl EqualCondition { - /// Constructs a new `EqualCondition` - /// If the `on_left_keys` and `on_right_keys` are empty, it means no equivalent condition - /// Note: `on_left_keys` and `on_right_keys` are either all empty or none of them. - fn new( - on_left_keys: Vec, - on_right_keys: Vec, - left_schema: &Schema, - right_schema: &Schema, - ) -> EqualCondition { - if !on_left_keys.is_empty() && on_left_keys.len() != on_right_keys.len() { - unreachable!("Unexpected join on condition.") - } - EqualCondition { - on_left_keys, - on_right_keys, - left_len: left_schema.len(), - right_len: right_schema.len(), - } - } - /// Compare left tuple and right tuple on equivalent condition /// `left_tuple` must be from the [`NestedLoopJoin::left_input`] /// `right_tuple` must be from the [`NestedLoopJoin::right_input`] @@ -132,16 +112,18 @@ impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for NestedLoopJoin { }; let (mut left_input, mut right_input) = (left_input, right_input); - let mut left_schema = left_input.output_schema().clone(); - let mut right_schema = right_input.output_schema().clone(); if matches!(join_type, JoinType::RightOuter) { std::mem::swap(&mut left_input, &mut right_input); std::mem::swap(&mut on_left_keys, &mut on_right_keys); - std::mem::swap(&mut left_schema, &mut right_schema); } - let eq_cond = EqualCondition::new(on_left_keys, on_right_keys, &left_schema, &right_schema); + let eq_cond = EqualCondition { + on_left_keys, + on_right_keys, + left_len: 0, + right_len: 0, + }; NestedLoopJoin { left_input_plan: left_input, @@ -156,64 +138,46 @@ impl From<(JoinOperator, LogicalPlan, LogicalPlan)> for NestedLoopJoin { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for NestedLoopJoin { - fn into_executor( - mut self, - arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId { - self.left_input = build_read(arena, self.left_input_plan.take(), cache, transaction); - arena.push(ExecNode::NestedLoopJoin(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for NestedLoopJoin { - type Input = (JoinOperator, LogicalPlan, LogicalPlan); + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId { - >::into_executor(Self::from(input), arena, cache, transaction) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - NestedLoopJoin::next_tuple(self, arena) - } -} - -impl NestedLoopJoin { - fn build_right_input<'a, T: Transaction + 'a>( - &mut self, - arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let cache = ( - arena.table_cache(), - arena.view_cache(), - arena.meta_cache(), - arena.scala_functions(), - arena.table_functions(), + let mut executor = input; + let left_len = executor.left_input_plan.output_schema(plan_arena).len(); + let right_len = executor.right_input_plan.output_schema(plan_arena).len(); + executor.eq_cond.left_len = left_len; + executor.eq_cond.right_len = right_len; + executor.left_input = build_read( + arena, + plan_arena, + executor.left_input_plan.take(), + cache, + transaction, ); - let transaction = arena.transaction_mut() as *mut T; - // Fixme: Executor reset - build_read(arena, self.right_input_plan.clone(), cache, transaction) + arena.push(ExecNode::NestedLoopJoin(executor)) } +} - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for NestedLoopJoin { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let mut state = std::mem::replace(&mut self.state, NestedLoopJoinState::End); loop { match state { NestedLoopJoinState::PullLeft { right_bitmap } => { - if !arena.next_tuple(self.left_input)? { + if !arena.next_tuple(self.left_input, plan_arena)? { if matches!(self.ty, JoinType::Full) { state = NestedLoopJoinState::EmitRightUnmatched { - right_input: self.build_right_input(arena), + right_input: self.build_right_input(arena, plan_arena), right_bitmap: right_bitmap.unwrap_or_default(), right_emit_index: 0, }; @@ -228,7 +192,7 @@ impl NestedLoopJoin { state = NestedLoopJoinState::ScanRight { active_left: ActiveLeftState { left_tuple, - right_input: self.build_right_input(arena), + right_input: self.build_right_input(arena, plan_arena), right_index: 0, has_matched: false, first_matches: Vec::new(), @@ -240,7 +204,7 @@ impl NestedLoopJoin { mut active_left, mut right_bitmap, } => { - while arena.next_tuple(active_left.right_input)? { + while arena.next_tuple(active_left.right_input, plan_arena)? { let right_tuple = arena.result_tuple().clone(); let idx = active_left.right_index; active_left.right_index += 1; @@ -369,7 +333,7 @@ impl NestedLoopJoin { right_bitmap, mut right_emit_index, } => { - while arena.next_tuple(right_input)? { + while arena.next_tuple(right_input, plan_arena)? { let mut right_tuple = arena.result_tuple().clone(); let idx = right_emit_index; right_emit_index += 1; @@ -399,6 +363,25 @@ impl NestedLoopJoin { } } } +} + +impl NestedLoopJoin { + fn build_right_input<'a, T: Transaction + 'a>( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> ExecId { + let cache = arena.context(); + let transaction = arena.transaction(); + // Fixme: Executor reset + build_read( + arena, + plan_arena, + self.right_input_plan.clone(), + cache, + transaction, + ) + } /// Emit a tuple according to the join type. /// @@ -449,8 +432,8 @@ impl NestedLoopJoin { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; - use crate::db::DataBaseBuilder; + use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::db::{CatalogKind, DataBaseBuilder}; use crate::execution::dql::test::build_integers; use crate::execution::try_collect; use crate::expression::BinaryOperator; @@ -460,18 +443,18 @@ mod test { use crate::planner::operator::values::ValuesOperator; use crate::planner::operator::Operator; use crate::planner::Childrens; - use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; + use crate::storage::rocksdb::RocksStorage; use crate::storage::Storage; use crate::types::evaluator::binary_create; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; use std::borrow::Cow; use std::collections::HashSet; - use std::hash::RandomState; - use std::sync::Arc; use tempfile::TempDir; - fn optimize_exprs(plan: LogicalPlan) -> Result { + fn optimize_exprs( + plan: LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { HepOptimizerPipeline::builder() .before_batch( "Expression Remapper".to_string(), @@ -480,7 +463,7 @@ mod test { ) .build() .instantiate(plan) - .find_best::(None) + .find_best(None, arena) } fn tuple_to_strings(tuple: &Tuple) -> Vec> { @@ -496,6 +479,7 @@ mod test { } fn build_join_values( + arena: &mut crate::planner::PlanArena, eq: bool, ) -> ( Vec<(ScalarExpression, ScalarExpression)>, @@ -506,28 +490,28 @@ mod test { let desc = ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(); let t1_columns = vec![ - ColumnRef::from(ColumnCatalog::new("c1".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c2".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c3".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c1".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c2".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c3".to_string(), true, desc.clone())), ]; let t2_columns = vec![ - ColumnRef::from(ColumnCatalog::new("c4".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c5".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("c6".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c4".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c5".to_string(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c6".to_string(), true, desc.clone())), ]; let on_keys = if eq { vec![( - ScalarExpression::column_expr(t1_columns[1].clone(), 1), - ScalarExpression::column_expr(t2_columns[1].clone(), 1), + ScalarExpression::column_expr(t1_columns[1], 1), + ScalarExpression::column_expr(t2_columns[1], 1), )] } else { vec![] }; - let values_t1 = LogicalPlan { - operator: Operator::Values(ValuesOperator { + let values_t1 = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![ vec![ DataValue::Int32(0), @@ -550,15 +534,13 @@ mod test { DataValue::Int32(7), ], ], - schema_ref: Arc::new(t1_columns), + schema_ref: t1_columns, }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; + Childrens::None, + ); - let values_t2 = LogicalPlan { - operator: Operator::Values(ValuesOperator { + let values_t2 = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![ vec![ DataValue::Int32(0), @@ -581,21 +563,19 @@ mod test { DataValue::Int32(1), ], ], - schema_ref: Arc::new(t2_columns), + schema_ref: t2_columns, }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; + Childrens::None, + ); let filter = ScalarExpression::Binary { op: crate::expression::BinaryOperator::Gt, left_expr: Box::new(ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::new("c1".to_owned(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c1".to_owned(), true, desc.clone())), 0, )), right_expr: Box::new(ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::new("c4".to_owned(), true, desc.clone())), + arena.alloc_column(ColumnCatalog::new("c4".to_owned(), true, desc.clone())), 3, )), evaluator: Some( @@ -632,11 +612,13 @@ mod test { fn test_nested_inner_join() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right, filter) = build_join_values(true); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right, filter) = build_join_values(&mut plan_arena, true); let plan = LogicalPlan::new( Operator::Join(JoinOperator { on: JoinCondition::On { @@ -650,15 +632,16 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( NestedLoopJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -681,11 +664,13 @@ mod test { fn test_nested_left_out_join() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right, filter) = build_join_values(true); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right, filter) = build_join_values(&mut plan_arena, true); let plan = LogicalPlan::new( Operator::Join(JoinOperator { on: JoinCondition::On { @@ -699,15 +684,16 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( NestedLoopJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -759,11 +745,13 @@ mod test { fn test_nested_cross_join_with_on() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right, filter) = build_join_values(true); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right, filter) = build_join_values(&mut plan_arena, true); let plan = LogicalPlan::new( Operator::Join(JoinOperator { on: JoinCondition::On { @@ -777,15 +765,16 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( NestedLoopJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -808,11 +797,13 @@ mod test { fn test_nested_cross_join_without_filter() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right, _) = build_join_values(true); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right, _) = build_join_values(&mut plan_arena, true); let plan = LogicalPlan::new( Operator::Join(JoinOperator { on: JoinCondition::On { @@ -826,15 +817,16 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( NestedLoopJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -872,11 +864,13 @@ mod test { fn test_nested_cross_join_without_on() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right, _) = build_join_values(false); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right, _) = build_join_values(&mut plan_arena, false); let plan = LogicalPlan::new( Operator::Join(JoinOperator { on: JoinCondition::On { @@ -890,15 +884,16 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( NestedLoopJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -911,11 +906,13 @@ mod test { fn test_nested_right_out_join() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right, filter) = build_join_values(true); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right, filter) = build_join_values(&mut plan_arena, true); let plan = LogicalPlan::new( Operator::Join(JoinOperator { on: JoinCondition::On { @@ -929,15 +926,16 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( NestedLoopJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -984,11 +982,13 @@ mod test { fn test_nested_full_join() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let (keys, left, right, filter) = build_join_values(true); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let (keys, left, right, filter) = build_join_values(&mut plan_arena, true); let plan = LogicalPlan::new( Operator::Join(JoinOperator { on: JoinCondition::On { @@ -1002,15 +1002,16 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() }; let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( NestedLoopJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -1086,55 +1087,50 @@ mod test { fn test_nested_right_join_filter_only_left_columns() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let mut transaction = storage.transaction()?; - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let transaction = storage.transaction()?; + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let desc = ColumnDesc::new(LogicalType::Integer, None, false, None)?; let left_columns = vec![ - ColumnRef::from(ColumnCatalog::new("k".to_string(), true, desc.clone())), - ColumnRef::from(ColumnCatalog::new("v".to_string(), true, desc.clone())), + plan_arena.alloc_column(ColumnCatalog::new("k".to_string(), true, desc.clone())), + plan_arena.alloc_column(ColumnCatalog::new("v".to_string(), true, desc.clone())), ]; - let right_columns = vec![ColumnRef::from(ColumnCatalog::new( - "rk".to_string(), - true, - desc.clone(), - ))]; + let right_columns = + vec![plan_arena.alloc_column(ColumnCatalog::new("rk".to_string(), true, desc.clone()))]; let on_keys = vec![( - ScalarExpression::column_expr(left_columns[0].clone(), 0), - ScalarExpression::column_expr(right_columns[0].clone(), 0), + ScalarExpression::column_expr(left_columns[0], 0), + ScalarExpression::column_expr(right_columns[0], 0), )]; let filter_expr = ScalarExpression::Binary { op: crate::expression::BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::column_expr(left_columns[1].clone(), 1)), + left_expr: Box::new(ScalarExpression::column_expr(left_columns[1], 1)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Boolean, }; - let left = LogicalPlan { - operator: Operator::Values(ValuesOperator { + let left = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![ vec![DataValue::Int32(2), DataValue::Int32(0)], vec![DataValue::Int32(2), DataValue::Int32(5)], ], - schema_ref: Arc::new(left_columns), + schema_ref: left_columns, }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; - let right = LogicalPlan { - operator: Operator::Values(ValuesOperator { + Childrens::None, + ); + let right = LogicalPlan::new( + Operator::Values(ValuesOperator { rows: vec![vec![DataValue::Int32(2)]], - schema_ref: Arc::new(right_columns), + schema_ref: right_columns, }), - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }; + Childrens::None, + ); let plan = LogicalPlan::new( Operator::Join(JoinOperator { @@ -1149,7 +1145,7 @@ mod test { right: Box::new(right), }, ); - let plan = optimize_exprs(plan)?; + let plan = optimize_exprs(plan, &mut plan_arena)?; let Operator::Join(op) = plan.operator else { unreachable!() @@ -1157,8 +1153,9 @@ mod test { let (left, right) = plan.childrens.pop_twins(); let executor = crate::execution::execute( NestedLoopJoin::from((op, left, right)), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ); let tuples = try_collect(executor)?; @@ -1178,7 +1175,7 @@ mod test { #[test] fn test_right_join_using_binds_visible_column_to_right_side() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let db = DataBaseBuilder::path(temp_dir.path()).build_in_memory()?; + let mut db = DataBaseBuilder::path(temp_dir.path()).build_in_memory()?; let setup_sql = [ "DROP TABLE IF EXISTS str1", @@ -1190,7 +1187,16 @@ mod test { ]; for sql in setup_sql { - db.run(sql)?.done()?; + if sql.starts_with("DROP ") || sql.starts_with("CREATE ") { + db.ddl(sql)?; + if sql.starts_with("CREATE TABLE str1") { + db.load(CatalogKind::Table("str1".to_string().into()))?; + } else if sql.starts_with("CREATE TABLE str2") { + db.load(CatalogKind::Table("str2".to_string().into()))?; + } + } else { + db.run(sql)?.done()?; + } } let mut iter = db.run( diff --git a/src/execution/dql/limit.rs b/src/execution/dql/limit.rs index a7ca9336..20bdf39d 100644 --- a/src/execution/dql/limit.rs +++ b/src/execution/dql/limit.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::planner::operator::limit::LimitOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; @@ -25,16 +27,17 @@ pub struct Limit { emitted: usize, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Limit { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Limit { type Input = (LimitOperator, LogicalPlan); fn into_executor( (LimitOperator { offset, limit }, input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let input = build_read(arena, input, cache, transaction); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::Limit(Limit { offset, limit, @@ -43,8 +46,14 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Limit { emitted: 0, })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Limit { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { let offset = self.offset.unwrap_or(0); let limit = self.limit.unwrap_or(usize::MAX); @@ -54,7 +63,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Limit { } loop { - if !arena.next_tuple(self.input)? { + if !arena.next_tuple(self.input, plan_arena)? { arena.finish(); return Ok(()); } diff --git a/src/execution/dql/mark_apply.rs b/src/execution/dql/mark_apply.rs index 8dfa8342..de4eaf5f 100644 --- a/src/execution/dql/mark_apply.rs +++ b/src/execution/dql/mark_apply.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::planner::operator::mark_apply::{MarkApplyKind, MarkApplyOperator, MarkApplyQuantifier}; use crate::planner::LogicalPlan; use crate::storage::Transaction; @@ -37,16 +39,17 @@ pub struct MarkApply { left_tuple: Tuple, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for MarkApply { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for MarkApply { type Input = (MarkApplyOperator, LogicalPlan, LogicalPlan); fn into_executor( (op, left_input, right_input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let left_input = build_read(arena, left_input, cache, transaction); + let left_input = build_read(arena, plan_arena, left_input, cache, transaction); arena.push(ExecNode::MarkApply(Self { op, right_input_plan: right_input, @@ -54,15 +57,21 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for MarkApply { left_tuple: Tuple::default(), })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - if !arena.next_tuple(self.left_input)? { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for MarkApply { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { + if !arena.next_tuple(self.left_input, plan_arena)? { arena.finish(); return Ok(()); } self.left_tuple = mem::take(arena.result_tuple_mut()); - let marker = self.mark_value(arena)?; + let marker = self.mark_value(arena, plan_arena)?; arena.produce_tuple(mem::take(&mut self.left_tuple)); arena.result_tuple_mut().values.push(marker); @@ -94,8 +103,13 @@ impl MarkApply { fn with_right_input<'a, T: Transaction + 'a, R>( &self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, param_value: Option, - f: impl FnOnce(&mut ExecArena<'a, T>, ExecId) -> Result, + f: impl FnOnce( + &mut ExecArena<'a, T>, + &mut crate::planner::PlanArena<'a>, + ExecId, + ) -> Result, ) -> Result { let runtime_probe = self.runtime_probe_for(param_value); let depth_before = arena.runtime_probe_depth(); @@ -103,17 +117,17 @@ impl MarkApply { arena.push_runtime_probe(runtime_probe); } - let cache = ( - arena.table_cache(), - arena.view_cache(), - arena.meta_cache(), - arena.scala_functions(), - arena.table_functions(), - ); - let transaction = arena.transaction_mut() as *mut T; + let cache = arena.context(); + let transaction = arena.transaction(); let result = { - let right_input = build_read(arena, self.right_input_plan.clone(), cache, transaction); - f(arena, right_input) + let right_input = build_read( + arena, + plan_arena, + self.right_input_plan.clone(), + cache, + transaction, + ); + f(arena, plan_arena, right_input) }; let depth_after = arena.runtime_probe_depth(); @@ -138,13 +152,15 @@ impl MarkApply { fn mark_value<'a, T: Transaction + 'a>( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result { match self.op.kind { MarkApplyKind::Exists => self.with_right_input( arena, + plan_arena, self.parameterized_probe_value()?, - |arena, right_input| { - while arena.next_tuple(right_input)? { + |arena, plan_arena, right_input| { + while arena.next_tuple(right_input, plan_arena)? { let right_tuple = arena.result_tuple(); if self.exists_predicate_matched(&self.left_tuple, right_tuple)? { return Ok(DataValue::Boolean(true)); @@ -159,9 +175,10 @@ impl MarkApply { if !probe_value.is_null() { if self.with_right_input( arena, + plan_arena, Some(probe_value), - |arena, right_input| { - while arena.next_tuple(right_input)? { + |arena, plan_arena, right_input| { + while arena.next_tuple(right_input, plan_arena)? { let right_tuple = arena.result_tuple(); if self.quantified_predicate_outcome( &self.left_tuple, @@ -180,9 +197,10 @@ impl MarkApply { if self.with_right_input( arena, + plan_arena, Some(DataValue::Null), - |arena, right_input| { - while arena.next_tuple(right_input)? { + |arena, plan_arena, right_input| { + while arena.next_tuple(right_input, plan_arena)? { let right_tuple = arena.result_tuple(); if self.quantified_predicate_outcome( &self.left_tuple, @@ -203,13 +221,23 @@ impl MarkApply { } } - self.with_right_input(arena, None, |arena, right_input| { - self.scan_quantified_right_input(arena, right_input, MarkApplyQuantifier::Any) + self.with_right_input(arena, plan_arena, None, |arena, plan_arena, right_input| { + self.scan_quantified_right_input( + arena, + plan_arena, + right_input, + MarkApplyQuantifier::Any, + ) }) } MarkApplyKind::Quantified(MarkApplyQuantifier::All) => { - self.with_right_input(arena, None, |arena, right_input| { - self.scan_quantified_right_input(arena, right_input, MarkApplyQuantifier::All) + self.with_right_input(arena, plan_arena, None, |arena, plan_arena, right_input| { + self.scan_quantified_right_input( + arena, + plan_arena, + right_input, + MarkApplyQuantifier::All, + ) }) } } @@ -218,12 +246,13 @@ impl MarkApply { fn scan_quantified_right_input<'a, T: Transaction + 'a>( &self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, right_input: ExecId, quantifier: MarkApplyQuantifier, ) -> Result { let mut saw_null = false; - while arena.next_tuple(right_input)? { + while arena.next_tuple(right_input, plan_arena)? { let right_tuple = arena.result_tuple(); match self.quantified_predicate_outcome(&self.left_tuple, right_tuple)? { QuantifiedPredicateOutcome::True => { @@ -323,28 +352,24 @@ mod tests { use crate::types::index::RuntimeIndexProbe; use crate::types::tuple::Tuple; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; use std::borrow::Cow; - use std::hash::RandomState; - use std::sync::Arc; use tempfile::TempDir; fn build_values_with_schema( + arena: &mut crate::planner::PlanArena, columns: Vec<(&str, LogicalType)>, rows: Vec>, ) -> LogicalPlan { - let schema_ref = Arc::new( - columns - .into_iter() - .map(|(name, ty)| { - ColumnRef::from(ColumnCatalog::new( - name.to_string(), - true, - ColumnDesc::new(ty, None, true, None).unwrap(), - )) - }) - .collect(), - ); + let schema_ref = columns + .into_iter() + .map(|(name, ty)| { + arena.alloc_column(ColumnCatalog::new( + name.to_string(), + true, + ColumnDesc::new(ty, None, true, None).unwrap(), + )) + }) + .collect(); LogicalPlan::new( Operator::Values(ValuesOperator { rows, schema_ref }), @@ -352,23 +377,27 @@ mod tests { ) } - fn build_values(name: &str, rows: Vec>) -> LogicalPlan { - build_values_with_schema(vec![(name, LogicalType::Integer)], rows) + fn build_values( + arena: &mut crate::planner::PlanArena, + name: &str, + rows: Vec>, + ) -> LogicalPlan { + build_values_with_schema(arena, vec![(name, LogicalType::Integer)], rows) } fn build_test_storage() -> Result< ( - Arc, - Arc, - Arc, + TableCache, + ViewCache, + StatisticsMetaCache, TempDir, RocksStorage, ), DatabaseError, > { - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; @@ -376,8 +405,8 @@ mod tests { Ok((table_cache, view_cache, meta_cache, temp_dir, storage)) } - fn build_marker_column() -> ColumnRef { - ColumnRef::from(ColumnCatalog::new( + fn build_marker_column(arena: &mut crate::planner::PlanArena) -> ColumnRef { + arena.alloc_column(ColumnCatalog::new( "__exists".to_string(), true, ColumnDesc::new(LogicalType::Boolean, None, true, None).unwrap(), @@ -404,29 +433,37 @@ mod tests { #[test] fn mark_exists_apply_appends_boolean_match_column() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let mut left = build_values( + &mut plan_arena, "left_c1", vec![vec![DataValue::Int32(1)], vec![DataValue::Int32(2)]], ); let mut right = build_values( + &mut plan_arena, "right_c1", vec![vec![DataValue::Int32(2)], vec![DataValue::Int32(3)]], ); - let left_column = left.output_schema()[0].clone(); - let right_column = right.output_schema()[0].clone(); + let left_column = left.output_schema(&mut plan_arena)[0]; + let right_column = right.output_schema(&mut plan_arena)[0]; let predicate = build_equality_predicate(left_column, 0, right_column, 1)?; let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, MarkApply>( ( - MarkApplyOperator::new_exists(build_marker_column(), vec![predicate]), + MarkApplyOperator::new_exists( + build_marker_column(&mut plan_arena), + vec![predicate], + ), left, right, ), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; assert_eq!( @@ -447,29 +484,37 @@ mod tests { #[test] fn mark_exists_apply_treats_null_predicate_as_not_matched() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let mut left = build_values( + &mut plan_arena, "left_c1", vec![vec![DataValue::Int32(1)], vec![DataValue::Int32(2)]], ); let mut right = build_values( + &mut plan_arena, "right_c1", vec![vec![DataValue::Null], vec![DataValue::Int32(2)]], ); - let left_column = left.output_schema()[0].clone(); - let right_column = right.output_schema()[0].clone(); + let left_column = left.output_schema(&mut plan_arena)[0]; + let right_column = right.output_schema(&mut plan_arena)[0]; let predicate = build_equality_predicate(left_column, 0, right_column, 1)?; let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, MarkApply>( ( - MarkApplyOperator::new_exists(build_marker_column(), vec![predicate]), + MarkApplyOperator::new_exists( + build_marker_column(&mut plan_arena), + vec![predicate], + ), left, right, ), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; assert_eq!( @@ -491,7 +536,10 @@ mod tests { #[test] fn mark_exists_apply_sets_runtime_probe_before_residual_predicates() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let mut left = build_values_with_schema( + &mut plan_arena, vec![ ("left_c1", LogicalType::Integer), ("left_flag", LogicalType::Integer), @@ -499,6 +547,7 @@ mod tests { vec![], ); let mut right = build_values_with_schema( + &mut plan_arena, vec![ ("right_c1", LogicalType::Integer), ("right_flag", LogicalType::Integer), @@ -508,25 +557,29 @@ mod tests { vec![DataValue::Int32(2), DataValue::Null], ], ); - let left_value_column = left.output_schema()[0].clone(); - let left_flag_column = left.output_schema()[1].clone(); - let right_value_column = right.output_schema()[0].clone(); - let right_flag_column = right.output_schema()[1].clone(); + let left_schema = left.output_schema(&mut plan_arena).clone(); + let right_schema = right.output_schema(&mut plan_arena).clone(); + let left_value_column = left_schema[0]; + let left_flag_column = left_schema[1]; + let right_value_column = right_schema[0]; + let right_flag_column = right_schema[1]; let probe_predicate = - build_equality_predicate(left_value_column.clone(), 0, right_value_column, 2)?; - let flag_predicate = - build_equality_predicate(left_flag_column.clone(), 1, right_flag_column, 3)?; + build_equality_predicate(left_value_column, 0, right_value_column, 2)?; + let flag_predicate = build_equality_predicate(left_flag_column, 1, right_flag_column, 3)?; let mut op = MarkApplyOperator::new_exists( - build_marker_column(), + build_marker_column(&mut plan_arena), vec![probe_predicate, flag_predicate], ); op.set_parameterized_probe(Some(ScalarExpression::column_expr(left_value_column, 0))); let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; - let mut arena = ExecArena::default(); - arena.init_context((&table_cache, &view_cache, &meta_cache), &mut transaction); + let transaction = storage.transaction()?; + let mut arena = ExecArena::new(); + arena.init_context( + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + &transaction, + ); let mut exec = MarkApply { op, @@ -535,7 +588,10 @@ mod tests { left_tuple: Tuple::new(None, vec![DataValue::Int32(2), DataValue::Int32(1)]), }; - assert_eq!(exec.mark_value(&mut arena)?, DataValue::Boolean(true)); + assert_eq!( + exec.mark_value(&mut arena, &mut plan_arena)?, + DataValue::Boolean(true) + ); assert_eq!( exec.runtime_probe_for(Some(DataValue::Int32(2))), Some(RuntimeIndexProbe::Eq(DataValue::Int32(2))) @@ -546,22 +602,32 @@ mod tests { #[test] fn mark_in_apply_sets_eq_runtime_probe_for_non_null_value() -> Result<(), DatabaseError> { - let mut left = build_values_with_schema(vec![("left_c1", LogicalType::Integer)], vec![]); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let mut left = build_values_with_schema( + &mut plan_arena, + vec![("left_c1", LogicalType::Integer)], + vec![], + ); let mut right = build_values_with_schema( + &mut plan_arena, vec![("right_c1", LogicalType::Integer)], vec![vec![DataValue::Int32(2)]], ); - let left_value_column = left.output_schema()[0].clone(); - let right_value_column = right.output_schema()[0].clone(); - let predicate = - build_equality_predicate(left_value_column.clone(), 0, right_value_column, 1)?; - let mut op = MarkApplyOperator::new_in(build_marker_column(), vec![predicate]); + let left_value_column = left.output_schema(&mut plan_arena)[0]; + let right_value_column = right.output_schema(&mut plan_arena)[0]; + let predicate = build_equality_predicate(left_value_column, 0, right_value_column, 1)?; + let mut op = + MarkApplyOperator::new_in(build_marker_column(&mut plan_arena), vec![predicate]); op.set_parameterized_probe(Some(ScalarExpression::column_expr(left_value_column, 0))); let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; - let mut arena = ExecArena::default(); - arena.init_context((&table_cache, &view_cache, &meta_cache), &mut transaction); + let transaction = storage.transaction()?; + let mut arena = ExecArena::new(); + arena.init_context( + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + &transaction, + ); let mut exec = MarkApply { op, @@ -570,7 +636,10 @@ mod tests { left_tuple: Tuple::new(None, vec![DataValue::Int32(2)]), }; - assert_eq!(exec.mark_value(&mut arena)?, DataValue::Boolean(true)); + assert_eq!( + exec.mark_value(&mut arena, &mut plan_arena)?, + DataValue::Boolean(true) + ); assert_eq!( exec.runtime_probe_for(Some(DataValue::Int32(2))), Some(RuntimeIndexProbe::Eq(DataValue::Int32(2))) @@ -581,22 +650,32 @@ mod tests { #[test] fn mark_in_apply_sets_scope_runtime_probe_for_null_value() -> Result<(), DatabaseError> { - let mut left = build_values_with_schema(vec![("left_c1", LogicalType::Integer)], vec![]); + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let mut left = build_values_with_schema( + &mut plan_arena, + vec![("left_c1", LogicalType::Integer)], + vec![], + ); let mut right = build_values_with_schema( + &mut plan_arena, vec![("right_c1", LogicalType::Integer)], vec![vec![DataValue::Null], vec![DataValue::Int32(2)]], ); - let left_value_column = left.output_schema()[0].clone(); - let right_value_column = right.output_schema()[0].clone(); - let predicate = - build_equality_predicate(left_value_column.clone(), 0, right_value_column, 1)?; - let mut op = MarkApplyOperator::new_in(build_marker_column(), vec![predicate]); + let left_value_column = left.output_schema(&mut plan_arena)[0]; + let right_value_column = right.output_schema(&mut plan_arena)[0]; + let predicate = build_equality_predicate(left_value_column, 0, right_value_column, 1)?; + let mut op = + MarkApplyOperator::new_in(build_marker_column(&mut plan_arena), vec![predicate]); op.set_parameterized_probe(Some(ScalarExpression::column_expr(left_value_column, 0))); let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; - let mut arena = ExecArena::default(); - arena.init_context((&table_cache, &view_cache, &meta_cache), &mut transaction); + let transaction = storage.transaction()?; + let mut arena = ExecArena::new(); + arena.init_context( + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + &transaction, + ); let mut exec = MarkApply { op, @@ -605,7 +684,10 @@ mod tests { left_tuple: Tuple::new(None, vec![DataValue::Null]), }; - assert_eq!(exec.mark_value(&mut arena)?, DataValue::Null); + assert_eq!( + exec.mark_value(&mut arena, &mut plan_arena)?, + DataValue::Null + ); assert_eq!( exec.runtime_probe_for(None), Some(RuntimeIndexProbe::Scope { @@ -619,29 +701,34 @@ mod tests { #[test] fn mark_in_apply_appends_boolean_match_column() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let mut left = build_values( + &mut plan_arena, "left_c1", vec![vec![DataValue::Int32(1)], vec![DataValue::Int32(2)]], ); let mut right = build_values( + &mut plan_arena, "right_c1", vec![vec![DataValue::Int32(2)], vec![DataValue::Int32(3)]], ); - let left_column = left.output_schema()[0].clone(); - let right_column = right.output_schema()[0].clone(); + let left_column = left.output_schema(&mut plan_arena)[0]; + let right_column = right.output_schema(&mut plan_arena)[0]; let predicate = build_equality_predicate(left_column, 0, right_column, 1)?; let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, MarkApply>( ( - MarkApplyOperator::new_in(build_marker_column(), vec![predicate]), + MarkApplyOperator::new_in(build_marker_column(&mut plan_arena), vec![predicate]), left, right, ), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; assert_eq!( @@ -662,29 +749,34 @@ mod tests { #[test] fn mark_in_apply_treats_null_predicate_as_not_matched() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let mut left = build_values( + &mut plan_arena, "left_c1", vec![vec![DataValue::Int32(1)], vec![DataValue::Int32(2)]], ); let mut right = build_values( + &mut plan_arena, "right_c1", vec![vec![DataValue::Null], vec![DataValue::Int32(2)]], ); - let left_column = left.output_schema()[0].clone(); - let right_column = right.output_schema()[0].clone(); + let left_column = left.output_schema(&mut plan_arena)[0]; + let right_column = right.output_schema(&mut plan_arena)[0]; let predicate = build_equality_predicate(left_column, 0, right_column, 1)?; let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, MarkApply>( ( - MarkApplyOperator::new_in(build_marker_column(), vec![predicate]), + MarkApplyOperator::new_in(build_marker_column(&mut plan_arena), vec![predicate]), left, right, ), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; assert_eq!( @@ -705,20 +797,25 @@ mod tests { #[test] fn mark_in_apply_ignores_null_correlated_predicate_rows() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let mut left = build_values( + &mut plan_arena, "left_c1", vec![vec![DataValue::Int32(1)], vec![DataValue::Int32(2)]], ); let mut right = build_values_with_schema( + &mut plan_arena, vec![ ("right_c1", LogicalType::Integer), ("right_flag", LogicalType::Integer), ], vec![vec![DataValue::Int32(1), DataValue::Null]], ); - let left_column = left.output_schema()[0].clone(); - let right_value_column = right.output_schema()[0].clone(); - let right_flag_column = right.output_schema()[1].clone(); + let left_column = left.output_schema(&mut plan_arena)[0]; + let right_schema = right.output_schema(&mut plan_arena).clone(); + let right_value_column = right_schema[0]; + let right_flag_column = right_schema[1]; let probe_predicate = build_equality_predicate(left_column, 0, right_value_column, 1)?; let correlated_predicate = ScalarExpression::Binary { @@ -733,18 +830,19 @@ mod tests { }; let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, MarkApply>( ( MarkApplyOperator::new_in( - build_marker_column(), + build_marker_column(&mut plan_arena), vec![probe_predicate, correlated_predicate], ), left, right, ), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; assert_eq!( diff --git a/src/execution/dql/projection.rs b/src/execution/dql/projection.rs index 4d44d3d3..4e3a5311 100644 --- a/src/execution/dql/projection.rs +++ b/src/execution/dql/projection.rs @@ -13,61 +13,55 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::expression::ScalarExpression; use crate::planner::operator::project::ProjectOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; -use crate::types::tuple::Tuple; -use crate::types::value::DataValue; pub struct Projection { exprs: Vec, input: ExecId, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Projection { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Projection { type Input = (ProjectOperator, LogicalPlan); fn into_executor( (ProjectOperator { exprs }, input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let input = build_read(arena, input, cache, transaction); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::Projection(Projection { exprs, input })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - if !arena.next_tuple(self.input)? { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Projection { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { + if !arena.next_tuple(self.input, plan_arena)? { arena.finish(); return Ok(()); } - arena.with_projection_tmp(|tuple, projection_tmp| { - projection_tmp.clear(); + arena.with_projection_tmp(|arena, projection_tmp| { + let tuple = arena.result_tuple(); projection_tmp.reserve(self.exprs.len()); for expr in self.exprs.iter() { projection_tmp.push(expr.eval(Some(tuple))?); } + std::mem::swap(&mut arena.result_tuple_mut().values, projection_tmp); Ok::<_, DatabaseError>(()) })?; arena.resume(); Ok(()) } } - -impl Projection { - pub fn projection( - tuple: &Tuple, - exprs: &[ScalarExpression], - ) -> Result, DatabaseError> { - let mut values = Vec::with_capacity(exprs.len()); - - for expr in exprs.iter() { - values.push(expr.eval(Some(tuple))?); - } - Ok(values) - } -} diff --git a/src/execution/dql/scalar_apply.rs b/src/execution/dql/scalar_apply.rs index d8f0ae77..2278c3b2 100644 --- a/src/execution/dql/scalar_apply.rs +++ b/src/execution/dql/scalar_apply.rs @@ -15,7 +15,9 @@ use std::mem; use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::planner::operator::scalar_apply::ScalarApplyOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; @@ -27,32 +29,39 @@ pub struct ScalarApply { cached_right: Option, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ScalarApply { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for ScalarApply { type Input = (ScalarApplyOperator, LogicalPlan, LogicalPlan); fn into_executor( (_, left_input, right_input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let left_input = build_read(arena, left_input, cache, transaction); - let right_input = build_read(arena, right_input, cache, transaction); + let left_input = build_read(arena, plan_arena, left_input, cache, transaction); + let right_input = build_read(arena, plan_arena, right_input, cache, transaction); arena.push(ExecNode::ScalarApply(Self { left_input, right_input, cached_right: None, })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - Self::load_right_once(&mut self.cached_right, self.right_input, arena)?; +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ScalarApply { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { + Self::load_right_once(&mut self.cached_right, self.right_input, arena, plan_arena)?; let right_tuple = self .cached_right .as_ref() .expect("scalar apply right tuple initialized"); - if !arena.next_tuple(self.left_input)? { + if !arena.next_tuple(self.left_input, plan_arena)? { arena.finish(); return Ok(()); } @@ -70,9 +79,10 @@ impl ScalarApply { cached_right: &mut Option, right_input: ExecId, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { if cached_right.is_none() { - if !arena.next_tuple(right_input)? { + if !arena.next_tuple(right_input, plan_arena)? { return Err(DatabaseError::InvalidValue( "scalar apply right input returned no rows".to_string(), )); @@ -87,7 +97,7 @@ impl ScalarApply { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use super::*; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::execution::{execute_input, try_collect}; use crate::planner::operator::scalar_subquery::ScalarSubqueryOperator; use crate::planner::operator::values::ValuesOperator; @@ -97,18 +107,15 @@ mod tests { use crate::storage::{StatisticsMetaCache, Storage, TableCache, ViewCache}; use crate::types::value::DataValue; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; - use std::hash::RandomState; - use std::sync::Arc; use tempfile::TempDir; - fn build_values(name: &str, rows: Vec>) -> LogicalPlan { + fn build_values( + arena: &mut crate::planner::PlanArena, + name: &str, + rows: Vec>, + ) -> LogicalPlan { let desc = ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(); - let schema_ref = Arc::new(vec![ColumnRef::from(ColumnCatalog::new( - name.to_string(), - true, - desc, - ))]); + let schema_ref = vec![arena.alloc_column(ColumnCatalog::new(name.to_string(), true, desc))]; LogicalPlan::new( Operator::Values(ValuesOperator { rows, schema_ref }), @@ -118,17 +125,17 @@ mod tests { fn build_test_storage() -> Result< ( - Arc, - Arc, - Arc, + TableCache, + ViewCache, + StatisticsMetaCache, TempDir, RocksStorage, ), DatabaseError, > { - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let view_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let meta_cache = crate::storage::StatisticsMetaCache::default(); + let view_cache = crate::storage::ViewCache::default(); + let table_cache = crate::storage::TableCache::default(); let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; @@ -138,7 +145,10 @@ mod tests { #[test] fn scalar_apply_repeats_scalar_result_for_each_left_row() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let left = build_values( + &mut plan_arena, "left_c1", vec![ vec![crate::types::value::DataValue::Int32(1)], @@ -146,16 +156,18 @@ mod tests { ], ); let right = ScalarSubqueryOperator::build(build_values( + &mut plan_arena, "right_c1", vec![vec![crate::types::value::DataValue::Int32(7)]], )); let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, ScalarApply>( (ScalarApplyOperator, left, right), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; let actual = tuples @@ -177,19 +189,26 @@ mod tests { #[test] fn scalar_apply_repeats_null_scalar_result_for_each_left_row() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); let left = build_values( + &mut plan_arena, "left_c1", vec![vec![DataValue::Int32(1)], vec![DataValue::Int32(2)]], ); - let right = - ScalarSubqueryOperator::build(build_values("right_c1", vec![vec![DataValue::Null]])); + let right = ScalarSubqueryOperator::build(build_values( + &mut plan_arena, + "right_c1", + vec![vec![DataValue::Null]], + )); let (table_cache, view_cache, meta_cache, _temp_dir, storage) = build_test_storage()?; - let mut transaction = storage.transaction()?; + let transaction = storage.transaction()?; let tuples = try_collect(execute_input::<_, ScalarApply>( (ScalarApplyOperator, left, right), - (&table_cache, &view_cache, &meta_cache), - &mut transaction, + crate::execution::empty_context(&table_cache, &view_cache, &meta_cache), + plan_arena, + &transaction, ))?; assert_eq!( diff --git a/src/execution/dql/scalar_subquery.rs b/src/execution/dql/scalar_subquery.rs index 0884e4de..71253e11 100644 --- a/src/execution/dql/scalar_subquery.rs +++ b/src/execution/dql/scalar_subquery.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::planner::operator::scalar_subquery::ScalarSubqueryOperator; use crate::planner::LogicalPlan; use crate::storage::Transaction; @@ -25,32 +27,39 @@ pub struct ScalarSubquery { returned: bool, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ScalarSubquery { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for ScalarSubquery { type Input = (ScalarSubqueryOperator, LogicalPlan); fn into_executor( (_, mut input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let value_count = input.output_schema().len(); - let input = build_read(arena, input, cache, transaction); + let value_count = input.output_schema(plan_arena).len(); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::ScalarSubquery(Self { input, value_count, returned: false, })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ScalarSubquery { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.returned { arena.finish(); return Ok(()); } self.returned = true; - let has_first = arena.next_tuple(self.input)?; + let has_first = arena.next_tuple(self.input, plan_arena)?; if !has_first { let output = arena.result_tuple_mut(); output.pk = None; @@ -62,7 +71,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ScalarSubquery { return Ok(()); } - if arena.next_tuple(self.input)? { + if arena.next_tuple(self.input, plan_arena)? { return Err(DatabaseError::InvalidValue( "scalar subquery returned more than one row".to_string(), )); diff --git a/src/execution/dql/seq_scan.rs b/src/execution/dql/seq_scan.rs index 8a214926..98b5513c 100644 --- a/src/execution/dql/seq_scan.rs +++ b/src/execution/dql/seq_scan.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::planner::operator::table_scan::TableScanOperator; use crate::storage::{Iter, Transaction, TupleIter}; @@ -32,35 +32,26 @@ impl<'a, T: Transaction + 'a> From for SeqScan<'a, T> { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for SeqScan<'a, T> { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::SeqScan(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for SeqScan<'a, T> { - type Input = TableScanOperator; + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::SeqScan(SeqScan::from(input))) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - SeqScan::next_tuple(self, arena) + let executor = input; + arena.push(ExecNode::SeqScan(executor)) } } -impl<'a, T: Transaction + 'a> SeqScan<'a, T> { - pub(crate) fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for SeqScan<'a, T> { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.iter.is_none() { let Some(TableScanOperator { table_name, @@ -73,8 +64,11 @@ impl<'a, T: Transaction + 'a> SeqScan<'a, T> { arena.finish(); return Ok(()); }; - self.iter = Some(arena.transaction_mut().read( - arena.table_cache(), + let state = arena.local_state(plan_arena); + self.iter = Some(state.transaction().read( + state.table_codec, + state.plan_arena, + state.context.table_cache, table_name, limit, columns, @@ -82,11 +76,12 @@ impl<'a, T: Transaction + 'a> SeqScan<'a, T> { )?); } + let state = arena.local_state(plan_arena); if self .iter .as_mut() .expect("seq scan iterator initialized") - .next_tuple_into(arena.result_tuple_mut())? + .next_tuple_into(state.table_codec, &mut state.result.tuple)? { arena.resume(); } else { diff --git a/src/execution/dql/set_membership.rs b/src/execution/dql/set_membership.rs index 19cdc170..88e4f4f5 100644 --- a/src/execution/dql/set_membership.rs +++ b/src/execution/dql/set_membership.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::execution::{ - build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor, + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, }; use crate::planner::operator::set_membership::SetMembershipKind; use crate::planner::LogicalPlan; @@ -49,42 +49,42 @@ impl From<(SetMembershipKind, LogicalPlan, LogicalPlan)> for SetMembership { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for SetMembership { - fn into_executor( - mut self, - arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId { - self.left_input = build_read(arena, self.left_plan.take(), cache, transaction); - self.right_input = build_read(arena, self.right_plan.take(), cache, transaction); - arena.push(ExecNode::SetMembership(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for SetMembership { - type Input = (SetMembershipKind, LogicalPlan, LogicalPlan); + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - >::into_executor(Self::from(input), arena, cache, transaction) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - SetMembership::next_tuple(self, arena) + let mut executor = input; + executor.left_input = build_read( + arena, + plan_arena, + executor.left_plan.take(), + cache, + transaction, + ); + executor.right_input = build_read( + arena, + plan_arena, + executor.right_plan.take(), + cache, + transaction, + ); + arena.push(ExecNode::SetMembership(executor)) } } -impl SetMembership { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for SetMembership { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { if !self.built { - while arena.next_tuple(self.right_input)? { + while arena.next_tuple(self.right_input, plan_arena)? { *self .right_counts .entry(arena.result_tuple().clone()) @@ -94,7 +94,7 @@ impl SetMembership { } loop { - if !arena.next_tuple(self.left_input)? { + if !arena.next_tuple(self.left_input, plan_arena)? { arena.finish(); return Ok(()); } @@ -111,7 +111,9 @@ impl SetMembership { } } } +} +impl SetMembership { fn consume_right_match(&mut self, tuple: &Tuple) -> bool { if let Some(count) = self.right_counts.get_mut(tuple) { if *count > 0 { diff --git a/src/execution/dql/show_table.rs b/src/execution/dql/show_table.rs index 8f655ecb..40a19909 100644 --- a/src/execution/dql/show_table.rs +++ b/src/execution/dql/show_table.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::ExecArena; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::storage::{TableIter, Transaction}; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; @@ -22,10 +22,30 @@ pub struct ShowTables<'a, T: Transaction + 'a> { pub(crate) metas: Option>, } -impl<'a, T: Transaction + 'a> ShowTables<'a, T> { - pub(crate) fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for ShowTables<'a, T> { + type Input = Self; + + fn into_executor( + input: Self::Input, + arena: &mut ExecArena<'a, T>, + _: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, + ) -> ExecId { + arena.push(ExecNode::ShowTables(input)) + } +} + +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowTables<'a, T> { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.metas.is_none() { - self.metas = Some(arena.transaction().tables()?); + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec(); + self.metas = Some(transaction.tables(table_codec)?); } let Some(table) = self diff --git a/src/execution/dql/show_view.rs b/src/execution/dql/show_view.rs index f9c14f08..e22ebd82 100644 --- a/src/execution/dql/show_view.rs +++ b/src/execution/dql/show_view.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::ExecArena; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::storage::{Transaction, ViewIter}; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; @@ -22,13 +22,36 @@ pub struct ShowViews<'a, T: Transaction + 'a> { pub(crate) metas: Option>, } -impl<'a, T: Transaction + 'a> ShowViews<'a, T> { - pub(crate) fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for ShowViews<'a, T> { + type Input = Self; + + fn into_executor( + input: Self::Input, + arena: &mut ExecArena<'a, T>, + _: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, + ) -> ExecId { + arena.push(ExecNode::ShowViews(input)) + } +} + +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowViews<'a, T> { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.metas.is_none() { - self.metas = Some(arena.transaction().views( - arena.table_cache(), - arena.scala_functions(), - arena.table_functions(), + let context = arena.context(); + let mut state = arena.local_state(plan_arena); + let (transaction, table_codec) = state.transaction_codec(); + self.metas = Some(transaction.views( + table_codec, + context.table_cache(), + plan_arena.table_arena_cell(), + context.scala_functions(), + context.table_functions(), )?); } diff --git a/src/execution/dql/sort.rs b/src/execution/dql/sort.rs index 94983d53..9b032b36 100644 --- a/src/execution/dql/sort.rs +++ b/src/execution/dql/sort.rs @@ -13,7 +13,9 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::planner::operator::sort::{SortField, SortOperator}; use crate::planner::LogicalPlan; use crate::storage::table_codec::BumpBytes; @@ -281,16 +283,17 @@ pub struct Sort { input: ExecId, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Sort { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Sort { type Input = (SortOperator, LogicalPlan); fn into_executor( (SortOperator { sort_fields, limit }, input): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let input = build_read(arena, input, cache, transaction); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::Sort(Sort { output: None, arena: Box::::default(), @@ -299,12 +302,18 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Sort { input, })) } +} - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Sort { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.output.is_none() { let mut tuples = NullableVec::new(&self.arena); - while arena.next_tuple(self.input)? { + while arena.next_tuple(self.input, plan_arena)? { let offset = tuples.len(); tuples.put((offset, arena.result_tuple().clone())); } @@ -337,7 +346,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Sort { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::errors::DatabaseError; use crate::execution::dql::sort::{radix_sort, NullableVec, SortBy}; use crate::expression::ScalarExpression; @@ -346,7 +355,6 @@ mod test { use crate::types::value::DataValue; use crate::types::LogicalType; use bumpalo::Bump; - use std::sync::Arc; #[test] fn test_radix_sort() { @@ -371,25 +379,28 @@ mod test { #[test] fn test_single_value_desc_and_null_first() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let sort_column = plan_arena.alloc_column(ColumnCatalog::new( + String::new(), + false, + ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), + )); let fn_sort_fields = |asc: bool, nulls_first: bool| { vec![SortField { expr: ScalarExpression::ColumnRef { - column: ColumnRef::from(ColumnCatalog::new( - String::new(), - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), + column: sort_column, position: 0, }, asc, nulls_first, }] }; - let _schema = Arc::new(vec![ColumnRef::from(ColumnCatalog::new( + let _schema = [plan_arena.alloc_column(ColumnCatalog::new( "c1".to_string(), true, ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), - ))]); + ))]; let arena = Bump::new(); let fn_tuples = || { @@ -518,49 +529,51 @@ mod test { #[test] fn test_mixed_value_desc_and_null_first() -> Result<(), DatabaseError> { - let fn_sort_fields = |asc_1: bool, - nulls_first_1: bool, - asc_2: bool, - nulls_first_2: bool| { - vec![ - SortField { - expr: ScalarExpression::ColumnRef { - column: ColumnRef::from(ColumnCatalog::new( - String::new(), - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - position: 0, + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let sort_column_1 = plan_arena.alloc_column(ColumnCatalog::new( + String::new(), + false, + ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), + )); + let sort_column_2 = plan_arena.alloc_column(ColumnCatalog::new( + String::new(), + false, + ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), + )); + let fn_sort_fields = + |asc_1: bool, nulls_first_1: bool, asc_2: bool, nulls_first_2: bool| { + vec![ + SortField { + expr: ScalarExpression::ColumnRef { + column: sort_column_1, + position: 0, + }, + asc: asc_1, + nulls_first: nulls_first_1, }, - asc: asc_1, - nulls_first: nulls_first_1, - }, - SortField { - expr: ScalarExpression::ColumnRef { - column: ColumnRef::from(ColumnCatalog::new( - String::new(), - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - position: 1, + SortField { + expr: ScalarExpression::ColumnRef { + column: sort_column_2, + position: 1, + }, + asc: asc_2, + nulls_first: nulls_first_2, }, - asc: asc_2, - nulls_first: nulls_first_2, - }, - ] - }; - let _schema = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new( + ] + }; + let _schema = [ + plan_arena.alloc_column(ColumnCatalog::new( "c1".to_string(), true, ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), )), - ColumnRef::from(ColumnCatalog::new( + plan_arena.alloc_column(ColumnCatalog::new( "c2".to_string(), true, ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), )), - ]); + ]; let arena = Bump::new(); let fn_tuples = || { diff --git a/src/execution/dql/top_k.rs b/src/execution/dql/top_k.rs index 93b55f8c..3fc9e258 100644 --- a/src/execution/dql/top_k.rs +++ b/src/execution/dql/top_k.rs @@ -14,7 +14,9 @@ use crate::errors::DatabaseError; use crate::execution::dql::sort::BumpVec; -use crate::execution::{build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode}; +use crate::execution::{ + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, +}; use crate::planner::operator::sort::SortField; use crate::planner::operator::top_k::TopKOperator; use crate::planner::LogicalPlan; @@ -96,7 +98,7 @@ pub struct TopK { input: ExecId, } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for TopK { +impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for TopK { type Input = (TopKOperator, LogicalPlan); fn into_executor( @@ -109,10 +111,11 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for TopK { input, ): Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - let input = build_read(arena, input, cache, transaction); + let input = build_read(arena, plan_arena, input, cache, transaction); arena.push(ExecNode::TopK(TopK { output: None, arena: Box::::default(), @@ -122,14 +125,20 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for TopK { input, })) } +} - #[allow(clippy::mutable_key_type)] - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for TopK { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, + ) -> Result<(), DatabaseError> { if self.output.is_none() { let keep_count = self.offset.unwrap_or(0) + self.limit; + #[allow(clippy::mutable_key_type)] let mut set = BTreeSet::new(); - while arena.next_tuple(self.input)? { + while arena.next_tuple(self.input, plan_arena)? { top_sort( &self.arena, &self.sort_fields, @@ -163,7 +172,7 @@ impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for TopK { #[cfg(all(test, not(target_arch = "wasm32")))] #[allow(clippy::mutable_key_type)] mod test { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::errors::DatabaseError; use crate::execution::dql::top_k::{top_sort, CmpItem}; use crate::expression::ScalarExpression; @@ -176,14 +185,17 @@ mod test { #[test] fn test_top_k_sort() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let sort_column = plan_arena.alloc_column(ColumnCatalog::new( + String::new(), + false, + ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), + )); let fn_sort_fields = |asc: bool, nulls_first: bool| { vec![SortField { expr: ScalarExpression::ColumnRef { - column: ColumnRef::from(ColumnCatalog::new( - String::new(), - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), + column: sort_column, position: 0, }, asc, @@ -346,37 +358,39 @@ mod test { #[test] fn test_top_k_sort_mix_values() -> Result<(), DatabaseError> { - let fn_sort_fields = |asc_1: bool, - nulls_first_1: bool, - asc_2: bool, - nulls_first_2: bool| { - vec![ - SortField { - expr: ScalarExpression::ColumnRef { - column: ColumnRef::from(ColumnCatalog::new( - String::new(), - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - position: 0, + let table_arena = crate::planner::TableArenaCell::default(); + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let sort_column_1 = plan_arena.alloc_column(ColumnCatalog::new( + String::new(), + false, + ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), + )); + let sort_column_2 = plan_arena.alloc_column(ColumnCatalog::new( + String::new(), + false, + ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), + )); + let fn_sort_fields = + |asc_1: bool, nulls_first_1: bool, asc_2: bool, nulls_first_2: bool| { + vec![ + SortField { + expr: ScalarExpression::ColumnRef { + column: sort_column_1, + position: 0, + }, + asc: asc_1, + nulls_first: nulls_first_1, }, - asc: asc_1, - nulls_first: nulls_first_1, - }, - SortField { - expr: ScalarExpression::ColumnRef { - column: ColumnRef::from(ColumnCatalog::new( - String::new(), - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - position: 1, + SortField { + expr: ScalarExpression::ColumnRef { + column: sort_column_2, + position: 1, + }, + asc: asc_2, + nulls_first: nulls_first_2, }, - asc: asc_2, - nulls_first: nulls_first_2, - }, - ] - }; + ] + }; let arena = Bump::new(); let fn_asc_1_and_nulls_first_1_and_asc_2_and_nulls_first_2_eq = diff --git a/src/execution/dql/union.rs b/src/execution/dql/union.rs index ac5807df..82484632 100644 --- a/src/execution/dql/union.rs +++ b/src/execution/dql/union.rs @@ -14,7 +14,7 @@ use crate::errors::DatabaseError; use crate::execution::{ - build_read, ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor, + build_read, ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor, }; use crate::planner::LogicalPlan; use crate::storage::Transaction; @@ -39,48 +39,48 @@ impl From<(LogicalPlan, LogicalPlan)> for Union { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Union { - fn into_executor( - mut self, - arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId { - self.left_input = build_read(arena, self.left_plan.take(), cache, transaction); - self.right_input = build_read(arena, self.right_plan.take(), cache, transaction); - arena.push(ExecNode::Union(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Union { - type Input = (LogicalPlan, LogicalPlan); + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, + plan_arena: &mut crate::planner::PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, ) -> ExecId { - >::into_executor(Self::from(input), arena, cache, transaction) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - Union::next_tuple(self, arena) + let mut executor = input; + executor.left_input = build_read( + arena, + plan_arena, + executor.left_plan.take(), + cache, + transaction, + ); + executor.right_input = build_read( + arena, + plan_arena, + executor.right_plan.take(), + cache, + transaction, + ); + arena.push(ExecNode::Union(executor)) } } -impl Union { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Union { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { if self.reading_left { - if arena.next_tuple(self.left_input)? { + if arena.next_tuple(self.left_input, plan_arena)? { arena.resume(); return Ok(()); } self.reading_left = false; } - if arena.next_tuple(self.right_input)? { + if arena.next_tuple(self.right_input, plan_arena)? { arena.resume(); } else { arena.finish(); diff --git a/src/execution/dql/values.rs b/src/execution/dql/values.rs index bcb7b76e..8367754c 100644 --- a/src/execution/dql/values.rs +++ b/src/execution/dql/values.rs @@ -13,16 +13,16 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionCaches, ExecutorNode, ReadExecutor}; +use crate::execution::{ExecArena, ExecId, ExecNode, ExecutionContext, ExecutorNode, ReadExecutor}; use crate::planner::operator::values::ValuesOperator; use crate::storage::Transaction; -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use crate::types::value::DataValue; use std::mem; pub struct Values { rows: std::vec::IntoIter>, - schema_ref: SchemaRef, + schema_ref: Schema, } impl From for Values { @@ -35,37 +35,25 @@ impl From for Values { } impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Values { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::Values(self)) - } -} - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Values { - type Input = ValuesOperator; + type Input = Self; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, + _plan_arena: &mut crate::planner::PlanArena<'a>, + _: ExecutionContext<'_>, + _: &T, ) -> ExecId { - arena.push(ExecNode::Values(Values::from(input))) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - Values::next_tuple(self, arena) + let executor = input; + arena.push(ExecNode::Values(executor)) } } -impl Values { - pub(crate) fn next_tuple<'a, T: Transaction + 'a>( +impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for Values { + fn next_tuple( &mut self, arena: &mut ExecArena<'a, T>, + plan_arena: &mut crate::planner::PlanArena<'a>, ) -> Result<(), DatabaseError> { let Some(mut values) = self.rows.next() else { arena.finish(); @@ -73,9 +61,9 @@ impl Values { }; for (i, value) in values.iter_mut().enumerate() { - let ty = self.schema_ref[i].datatype().clone(); + let ty = plan_arena.column(self.schema_ref[i]).datatype(); - *value = mem::replace(value, DataValue::Null).cast(&ty)?; + *value = mem::replace(value, DataValue::Null).cast(ty)?; } let output = arena.result_tuple_mut(); diff --git a/src/execution/mod.rs b/src/execution/mod.rs index ab8b60e6..db16ec13 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -13,9 +13,12 @@ // limitations under the License. pub(crate) mod ddl; +mod ddl_apply; pub(crate) mod dml; pub(crate) mod dql; +pub(crate) use ddl_apply::DDLApply; + use self::ddl::add_column::AddColumn; use self::ddl::change_column::ChangeColumn; use self::dql::join::nested_loop_join::NestedLoopJoin; @@ -32,7 +35,9 @@ use crate::execution::ddl::drop_table::DropTable; use crate::execution::ddl::drop_view::DropView; use crate::execution::ddl::truncate::Truncate; use crate::execution::dml::analyze::Analyze; +#[cfg(feature = "copy")] use crate::execution::dml::copy_from_file::CopyFromFile; +#[cfg(feature = "copy")] use crate::execution::dml::copy_to_file::CopyToFile; use crate::execution::dml::delete::Delete; use crate::execution::dml::insert::Insert; @@ -58,31 +63,63 @@ use crate::execution::dql::sort::Sort; use crate::execution::dql::top_k::TopK; use crate::execution::dql::union::Union; use crate::execution::dql::values::Values; +use crate::expression::ScalarExpression; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl}; -use crate::planner::LogicalPlan; +use crate::planner::{LogicalPlan, PlanArena}; +use crate::storage::table_codec::TableCodec; use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; use crate::types::index::RuntimeIndexProbe; -use crate::types::tuple::Tuple; +use crate::types::tuple::{Tuple, TupleLike}; use crate::types::value::DataValue; -pub(crate) type ExecutionCaches<'a> = ( - &'a TableCache, - &'a ViewCache, - &'a StatisticsMetaCache, - &'a ScalaFunctions, - &'a TableFunctions, -); - -pub(crate) trait IntoExecutionCaches<'a> { - fn into_execution_caches(self) -> ExecutionCaches<'a>; +#[derive(Clone, Copy)] +pub(crate) struct ExecutionContext<'a> { + table_cache: &'a TableCache, + view_cache: &'a ViewCache, + meta_cache: &'a StatisticsMetaCache, + scala_functions: &'a ScalaFunctions, + table_functions: &'a TableFunctions, } -impl<'a> IntoExecutionCaches<'a> for ExecutionCaches<'a> { - fn into_execution_caches(self) -> ExecutionCaches<'a> { - self +impl<'a> ExecutionContext<'a> { + pub(crate) fn new( + table_cache: &'a TableCache, + view_cache: &'a ViewCache, + meta_cache: &'a StatisticsMetaCache, + scala_functions: &'a ScalaFunctions, + table_functions: &'a TableFunctions, + ) -> Self { + Self { + table_cache, + view_cache, + meta_cache, + scala_functions, + table_functions, + } + } + + pub(crate) fn table_cache(self) -> &'a TableCache { + self.table_cache + } + + pub(crate) fn scala_functions(self) -> &'a ScalaFunctions { + self.scala_functions + } + + pub(crate) fn table_functions(self) -> &'a TableFunctions { + self.table_functions + } + + fn is_same_context(&self, other: ExecutionContext<'_>) -> bool { + std::ptr::eq(self.table_cache, other.table_cache) + && std::ptr::eq(self.view_cache, other.view_cache) + && std::ptr::eq(self.meta_cache, other.meta_cache) + && std::ptr::eq(self.scala_functions, other.scala_functions) + && std::ptr::eq(self.table_functions, other.table_functions) } } + pub(crate) type ExecId = usize; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -107,23 +144,18 @@ impl<'a, T: Transaction + 'a> Executor<'a, T> { Self { arena, root } } - pub(crate) fn next_tuple(&mut self) -> Result, DatabaseError> { - if !self.arena.next_tuple(self.root)? { + pub(crate) fn next_tuple( + &mut self, + plan_arena: &mut PlanArena<'a>, + ) -> Result, DatabaseError> { + if !self.arena.next_tuple(self.root, plan_arena)? { return Ok(None); } Ok(Some(self.arena.result_tuple())) } -} -impl Iterator for Executor<'_, T> { - type Item = Result; - - fn next(&mut self) -> Option { - match self.next_tuple() { - Ok(Some(tuple)) => Some(Ok(tuple.clone())), - Ok(None) => None, - Err(err) => Some(Err(err)), - } + pub(crate) fn take_ddl_apply(&mut self) -> Vec { + self.arena.take_ddl_apply() } } @@ -132,7 +164,9 @@ pub(crate) enum ExecNode<'a, T: Transaction + 'a> { AddColumn(AddColumn), Analyze(Analyze), ChangeColumn(ChangeColumn), + #[cfg(feature = "copy")] CopyFromFile(CopyFromFile), + #[cfg(feature = "copy")] CopyToFile(CopyToFile), CreateIndex(CreateIndex), CreateTable(CreateTable), @@ -173,109 +207,142 @@ pub(crate) enum ExecNode<'a, T: Transaction + 'a> { } pub(crate) trait ExecutorNode<'a, T: Transaction + 'a>: Sized { - type Input; - - fn into_executor( - input: Self::Input, + fn next_tuple( + &mut self, arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId; - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError>; + plan_arena: &mut PlanArena<'a>, + ) -> Result<(), DatabaseError>; } impl<'a, T: Transaction + 'a> ExecNode<'a, T> { - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { + fn next_tuple( + &mut self, + arena: &mut ExecArena<'a, T>, + plan_arena: &mut PlanArena<'a>, + ) -> Result<(), DatabaseError> { match self { ExecNode::AddColumn(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Analyze(exec) => { + >::next_tuple(exec, arena, plan_arena) } - ExecNode::Analyze(exec) => >::next_tuple(exec, arena), ExecNode::ChangeColumn(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } + #[cfg(feature = "copy")] ExecNode::CopyFromFile(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } + #[cfg(feature = "copy")] ExecNode::CopyToFile(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::CreateIndex(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::CreateTable(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::CreateView(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Delete(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Describe(exec) => { + >::next_tuple(exec, arena, plan_arena) } - ExecNode::Delete(exec) => >::next_tuple(exec, arena), - ExecNode::Describe(exec) => >::next_tuple(exec, arena), ExecNode::DropColumn(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::DropIndex(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::DropTable(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::DropView(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Dummy(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Explain(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Filter(exec) => { + >::next_tuple(exec, arena, plan_arena) } - ExecNode::DropView(exec) => >::next_tuple(exec, arena), - ExecNode::Dummy(exec) => >::next_tuple(exec, arena), - ExecNode::Explain(exec) => >::next_tuple(exec, arena), - ExecNode::Filter(exec) => >::next_tuple(exec, arena), ExecNode::FunctionScan(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::HashAgg(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::HashJoin(exec) => { + >::next_tuple(exec, arena, plan_arena) } - ExecNode::HashJoin(exec) => >::next_tuple(exec, arena), ExecNode::IndexScan(exec) => { - as ExecutorNode<'a, T>>::next_tuple(exec, arena) + as ExecutorNode<'a, T>>::next_tuple(exec, arena, plan_arena) + } + ExecNode::Insert(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Limit(exec) => { + >::next_tuple(exec, arena, plan_arena) } - ExecNode::Insert(exec) => >::next_tuple(exec, arena), - ExecNode::Limit(exec) => >::next_tuple(exec, arena), ExecNode::MarkApply(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::NestedLoopJoin(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::Projection(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::ScalarApply(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::ScalarSubquery(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::SetMembership(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) } ExecNode::SeqScan(exec) => { - as ExecutorNode<'a, T>>::next_tuple(exec, arena) + as ExecutorNode<'a, T>>::next_tuple(exec, arena, plan_arena) } ExecNode::ShowTables(exec) => { - as ExecutorNode<'a, T>>::next_tuple(exec, arena) + as ExecutorNode<'a, T>>::next_tuple(exec, arena, plan_arena) } ExecNode::ShowViews(exec) => { - as ExecutorNode<'a, T>>::next_tuple(exec, arena) + as ExecutorNode<'a, T>>::next_tuple(exec, arena, plan_arena) } ExecNode::SimpleAgg(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Sort(exec) => { + >::next_tuple(exec, arena, plan_arena) } - ExecNode::Sort(exec) => >::next_tuple(exec, arena), ExecNode::StreamDistinct(exec) => { - >::next_tuple(exec, arena) + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::TopK(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Truncate(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Union(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Update(exec) => { + >::next_tuple(exec, arena, plan_arena) + } + ExecNode::Values(exec) => { + >::next_tuple(exec, arena, plan_arena) } - ExecNode::TopK(exec) => >::next_tuple(exec, arena), - ExecNode::Truncate(exec) => >::next_tuple(exec, arena), - ExecNode::Union(exec) => >::next_tuple(exec, arena), - ExecNode::Update(exec) => >::next_tuple(exec, arena), - ExecNode::Values(exec) => >::next_tuple(exec, arena), ExecNode::Empty => unreachable!("executor node re-entered while active"), } } @@ -284,41 +351,102 @@ impl<'a, T: Transaction + 'a> ExecNode<'a, T> { pub(crate) struct ExecArena<'a, T: Transaction + 'a> { nodes: Vec>, result: ExecResult, + table_codec: TableCodec, projection_tmp: Vec, - cache: Option>, + context: Option>, transaction: *mut T, runtime_probe_stack: Vec, + ddl_apply: Vec, +} + +pub(crate) struct ExecArenaLocalState<'b, 'a, T: Transaction + 'a> { + transaction: *mut T, + pub(crate) table_codec: &'b mut TableCodec, + pub(crate) context: ExecutionContext<'a>, + pub(crate) result: &'b mut ExecResult, + pub(crate) plan_arena: &'b PlanArena<'a>, + ddl_apply: &'b mut Vec, +} + +impl<'b, 'a, T: Transaction + 'a> ExecArenaLocalState<'b, 'a, T> { + pub(crate) fn transaction(&self) -> &'a T { + unsafe { &*self.transaction } + } + + pub(crate) fn transaction_codec_mut(&mut self) -> (&mut T, &mut TableCodec) { + unsafe { (&mut *self.transaction, &mut *self.table_codec) } + } + + pub(crate) fn transaction_codec(&mut self) -> (&'a T, &mut TableCodec) { + unsafe { (&*self.transaction, &mut *self.table_codec) } + } + + pub(crate) fn write_transaction_codec_ddl_apply_mut( + &mut self, + ) -> (&mut T, &mut TableCodec, &mut Vec) { + unsafe { + ( + &mut *self.transaction, + &mut *self.table_codec, + self.ddl_apply, + ) + } + } } -impl<'a, T: Transaction + 'a> Default for ExecArena<'a, T> { - fn default() -> Self { +impl<'a, T: Transaction + 'a> ExecArena<'a, T> { + pub(crate) fn new() -> Self { Self { nodes: Vec::new(), result: ExecResult::default(), + table_codec: TableCodec::default(), projection_tmp: Vec::new(), - cache: None, + context: None, transaction: std::ptr::null_mut(), runtime_probe_stack: Vec::new(), + ddl_apply: Vec::new(), } } } +pub(crate) fn with_projection_tmp_value<'a, T: Transaction + 'a>( + arena: &mut ExecArena<'a, T>, + tuple: Option<&dyn TupleLike>, + exprs: &[ScalarExpression], + f: impl FnOnce(&mut ExecArena<'a, T>, &DataValue) -> Result<(), DatabaseError>, +) -> Result<(), DatabaseError> { + arena.with_projection_tmp(|arena, projection_tmp| { + { + let tuple = tuple.unwrap_or_else(|| arena.result_tuple() as &dyn TupleLike); + projection_tmp.reserve(exprs.len()); + for expr in exprs.iter() { + projection_tmp.push(expr.eval(Some(tuple))?); + } + } + + if projection_tmp.len() > 1 { + let value = DataValue::Tuple(std::mem::take(projection_tmp), false); + let ret = f(arena, &value); + let DataValue::Tuple(values, _) = value else { + unreachable!() + }; + *projection_tmp = values; + ret?; + } else if let Some(value) = projection_tmp.first() { + f(arena, value)?; + } + Ok(()) + }) +} + impl<'a, T: Transaction + 'a> ExecArena<'a, T> { - pub(crate) fn init_context(&mut self, cache: C, transaction: *mut T) - where - C: IntoExecutionCaches<'a>, - { - let cache = cache.into_execution_caches(); - if let Some(current) = self.cache { - debug_assert!(std::ptr::eq(current.0, cache.0)); - debug_assert!(std::ptr::eq(current.1, cache.1)); - debug_assert!(std::ptr::eq(current.2, cache.2)); - debug_assert!(std::ptr::eq(current.3, cache.3)); - debug_assert!(std::ptr::eq(current.4, cache.4)); - debug_assert_eq!(self.transaction, transaction); + pub(crate) fn init_context(&mut self, context: ExecutionContext<'a>, transaction: &'a T) { + if let Some(current) = &self.context { + debug_assert!(current.is_same_context(context)); + debug_assert_eq!(self.transaction, transaction as *const T as *mut T); } else { - self.cache = Some(cache); - self.transaction = transaction; + self.context = Some(context); + self.transaction = transaction as *const T as *mut T; } } @@ -328,32 +456,52 @@ impl<'a, T: Transaction + 'a> ExecArena<'a, T> { id } - pub(crate) fn table_cache(&self) -> &'a TableCache { - self.cache.expect("execution arena context initialized").0 + pub(crate) fn push_ddl_apply(&mut self, apply: DDLApply) { + self.ddl_apply.push(apply); } - pub(crate) fn view_cache(&self) -> &'a ViewCache { - self.cache.expect("execution arena context initialized").1 + pub(crate) fn take_ddl_apply(&mut self) -> Vec { + std::mem::take(&mut self.ddl_apply) } - pub(crate) fn meta_cache(&self) -> &'a StatisticsMetaCache { - self.cache.expect("execution arena context initialized").2 + pub(crate) fn context(&self) -> ExecutionContext<'a> { + *self + .context + .as_ref() + .expect("execution arena context initialized") } - pub(crate) fn scala_functions(&self) -> &'a ScalaFunctions { - self.cache.expect("execution arena context initialized").3 - } - - pub(crate) fn table_functions(&self) -> &'a TableFunctions { - self.cache.expect("execution arena context initialized").4 + pub(crate) fn table_cache(&self) -> &TableCache { + self.context + .as_ref() + .expect("execution arena context initialized") + .table_cache } pub(crate) fn transaction(&self) -> &'a T { unsafe { &*self.transaction } } - pub(crate) fn transaction_mut(&mut self) -> &'a mut T { - unsafe { &mut *self.transaction } + pub(crate) fn transaction_codec_mut(&mut self) -> (&mut T, &mut TableCodec) { + (unsafe { &mut *self.transaction }, &mut self.table_codec) + } + + pub(crate) fn local_state<'b>( + &'b mut self, + plan_arena: &'b PlanArena<'a>, + ) -> ExecArenaLocalState<'b, 'a, T> { + let context = *self + .context + .as_ref() + .expect("execution arena context initialized"); + ExecArenaLocalState { + transaction: self.transaction, + table_codec: &mut self.table_codec, + context, + result: &mut self.result, + plan_arena, + ddl_apply: &mut self.ddl_apply, + } } pub(crate) fn push_runtime_probe(&mut self, value: RuntimeIndexProbe) { @@ -381,18 +529,16 @@ impl<'a, T: Transaction + 'a> ExecArena<'a, T> { } #[inline] - pub(crate) fn with_projection_tmp( + pub(crate) fn with_projection_tmp( &mut self, - f: impl FnOnce(&Tuple, &mut Vec) -> Result, - ) -> Result { - let ExecArena { - result, - projection_tmp, - .. - } = self; - let ret = f(&result.tuple, projection_tmp)?; - std::mem::swap(&mut result.tuple.values, projection_tmp); - Ok(ret) + f: impl FnOnce(&mut Self, &mut Vec) -> Result, + ) -> Result { + let mut projection_tmp = std::mem::take(&mut self.projection_tmp); + projection_tmp.clear(); + let ret = f(self, &mut projection_tmp); + projection_tmp.clear(); + self.projection_tmp = projection_tmp; + ret } #[inline] @@ -411,10 +557,14 @@ impl<'a, T: Transaction + 'a> ExecArena<'a, T> { self.resume(); } - pub(crate) fn next_tuple(&mut self, id: ExecId) -> Result { + pub(crate) fn next_tuple( + &mut self, + id: ExecId, + plan_arena: &mut PlanArena<'a>, + ) -> Result { self.result.status = None; let mut node = std::mem::replace(&mut self.nodes[id], ExecNode::Empty); - let result = node.next_tuple(self); + let result = node.next_tuple(self, plan_arena); self.nodes[id] = node; result?; @@ -426,218 +576,51 @@ impl<'a, T: Transaction + 'a> ExecArena<'a, T> { } pub(crate) trait ReadExecutor<'a, T: Transaction + 'a>: Sized { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId; -} - -pub(crate) trait WriteExecutor<'a, T: Transaction + 'a>: Sized { - fn into_executor( - self, - arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId; -} - -macro_rules! impl_read_executor_node_via_from { - ($ty:ty, $input:ty) => { - impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for $ty - where - Self: ReadExecutor<'a, T> + From<$input>, - { - type Input = $input; - - fn into_executor( - input: Self::Input, - arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId { - >::into_executor( - Self::from(input), - arena, - cache, - transaction, - ) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - <$ty>::next_tuple(self, arena) - } - } - }; -} - -macro_rules! impl_write_executor_node_via_from { - ($ty:ty, $input:ty) => { - impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for $ty - where - Self: WriteExecutor<'a, T> + From<$input>, - { - type Input = $input; - - fn into_executor( - input: Self::Input, - arena: &mut ExecArena<'a, T>, - cache: ExecutionCaches<'a>, - transaction: *mut T, - ) -> ExecId { - >::into_executor( - Self::from(input), - arena, - cache, - transaction, - ) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - <$ty>::next_tuple(self, arena) - } - } - }; -} - -impl_read_executor_node_via_from!( - CopyToFile, - ( - crate::planner::operator::copy_to_file::CopyToFileOperator, - LogicalPlan - ) -); - -impl_write_executor_node_via_from!( - AddColumn, - crate::planner::operator::alter_table::add_column::AddColumnOperator -); -impl_write_executor_node_via_from!( - Analyze, - ( - crate::planner::operator::analyze::AnalyzeOperator, - LogicalPlan - ) -); -impl_write_executor_node_via_from!( - ChangeColumn, - crate::planner::operator::alter_table::change_column::ChangeColumnOperator -); -impl_write_executor_node_via_from!( - CopyFromFile, - crate::planner::operator::copy_from_file::CopyFromFileOperator -); -impl_write_executor_node_via_from!( - CreateIndex, - ( - crate::planner::operator::create_index::CreateIndexOperator, - LogicalPlan - ) -); -impl_write_executor_node_via_from!( - CreateTable, - crate::planner::operator::create_table::CreateTableOperator -); -impl_write_executor_node_via_from!( - CreateView, - crate::planner::operator::create_view::CreateViewOperator -); -impl_write_executor_node_via_from!( - Delete, - ( - crate::planner::operator::delete::DeleteOperator, - LogicalPlan - ) -); -impl_write_executor_node_via_from!( - DropColumn, - crate::planner::operator::alter_table::drop_column::DropColumnOperator -); -impl_write_executor_node_via_from!( - DropIndex, - crate::planner::operator::drop_index::DropIndexOperator -); -impl_write_executor_node_via_from!( - DropTable, - crate::planner::operator::drop_table::DropTableOperator -); -impl_write_executor_node_via_from!( - DropView, - crate::planner::operator::drop_view::DropViewOperator -); -impl_write_executor_node_via_from!( - Insert, - ( - crate::planner::operator::insert::InsertOperator, - LogicalPlan - ) -); -impl_write_executor_node_via_from!( - Truncate, - crate::planner::operator::truncate::TruncateOperator -); -impl_write_executor_node_via_from!( - Update, - ( - crate::planner::operator::update::UpdateOperator, - LogicalPlan - ) -); - -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowTables<'a, T> { - type Input = Self; + type Input; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::ShowTables(input)) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - ShowTables::next_tuple(self, arena) - } + plan_arena: &mut PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, + ) -> ExecId; } -impl<'a, T: Transaction + 'a> ExecutorNode<'a, T> for ShowViews<'a, T> { - type Input = Self; +pub(crate) trait WriteExecutor<'a, T: Transaction + 'a>: Sized { + type Input; fn into_executor( input: Self::Input, arena: &mut ExecArena<'a, T>, - _: ExecutionCaches<'a>, - _: *mut T, - ) -> ExecId { - arena.push(ExecNode::ShowViews(input)) - } - - fn next_tuple(&mut self, arena: &mut ExecArena<'a, T>) -> Result<(), DatabaseError> { - ShowViews::next_tuple(self, arena) - } + plan_arena: &mut PlanArena<'a>, + cache: ExecutionContext<'_>, + transaction: &T, + ) -> ExecId; } -pub(crate) fn build_read<'a, T: Transaction + 'a>( +pub(crate) fn build_read<'a, T>( arena: &mut ExecArena<'a, T>, + plan_arena: &mut PlanArena<'a>, plan: LogicalPlan, - cache: ExecutionCaches<'a>, - transaction: *mut T, -) -> ExecId { - arena.init_context(cache, transaction); - + cache: ExecutionContext<'_>, + transaction: &T, +) -> ExecId +where + T: Transaction + 'a, +{ let LogicalPlan { operator, childrens, physical_option, - _output_schema_ref, + .. } = plan; match operator { - Operator::Dummy => >::into_executor( + Operator::Dummy => >::into_executor( Dummy::default(), arena, + plan_arena, cache, transaction, ), @@ -645,9 +628,10 @@ pub(crate) fn build_read<'a, T: Transaction + 'a>( let input = childrens.pop_only(); if op.groupby_exprs.is_empty() { - >::into_executor( + >::into_executor( (op, input), arena, + plan_arena, cache, transaction, ) @@ -661,41 +645,46 @@ pub(crate) fn build_read<'a, T: Transaction + 'a>( }) ) { - >::into_executor( + >::into_executor( (op, input), arena, + plan_arena, cache, transaction, ) } else { - >::into_executor( + >::into_executor( (op, input), arena, + plan_arena, cache, transaction, ) } } - Operator::Filter(op) => >::into_executor( + Operator::Filter(op) => >::into_executor( (op, childrens.pop_only()), arena, + plan_arena, cache, transaction, ), Operator::ScalarApply(op) => { let (left, right) = childrens.pop_twins(); - >::into_executor( + >::into_executor( (op, left, right), arena, + plan_arena, cache, transaction, ) } Operator::MarkApply(op) => { let (left, right) = childrens.pop_twins(); - >::into_executor( + >::into_executor( (op, left, right), arena, + plan_arena, cache, transaction, ) @@ -714,30 +703,34 @@ pub(crate) fn build_read<'a, T: Transaction + 'a>( let (left, right) = childrens.pop_twins(); if use_hash_join { - >::into_executor( - (op, left, right), + >::into_executor( + HashJoin::from((op, left, right)), arena, + plan_arena, cache, transaction, ) } else { - >::into_executor( - (op, left, right), + >::into_executor( + NestedLoopJoin::from((op, left, right)), arena, + plan_arena, cache, transaction, ) } } - Operator::Project(op) => >::into_executor( + Operator::Project(op) => >::into_executor( (op, childrens.pop_only()), arena, + plan_arena, cache, transaction, ), - Operator::ScalarSubquery(op) => >::into_executor( + Operator::ScalarSubquery(op) => >::into_executor( (op, childrens.pop_only()), arena, + plan_arena, cache, transaction, ), @@ -748,79 +741,106 @@ pub(crate) fn build_read<'a, T: Transaction + 'a>( }) = physical_option { if let Some(lookup) = index_info.lookup.clone() { - return as ExecutorNode<'a, T>>::into_executor( - ( + return as ReadExecutor<'a, T>>::into_executor( + IndexScan::from(( op, - index_info.meta.clone(), + index_info.meta, lookup, index_info.covered_deserializers.clone(), index_info.cover_mapping.clone(), - ), + )), arena, + plan_arena, cache, transaction, ); } } - as ExecutorNode<'a, T>>::into_executor(op, arena, cache, transaction) - } - Operator::FunctionScan(op) => { - >::into_executor(op, arena, cache, transaction) + as ReadExecutor<'a, T>>::into_executor( + SeqScan::from(op), + arena, + plan_arena, + cache, + transaction, + ) } - Operator::Sort(op) => >::into_executor( + Operator::FunctionScan(op) => >::into_executor( + FunctionScan::from(op), + arena, + plan_arena, + cache, + transaction, + ), + Operator::Sort(op) => >::into_executor( (op, childrens.pop_only()), arena, + plan_arena, cache, transaction, ), - Operator::Limit(op) => >::into_executor( + Operator::Limit(op) => >::into_executor( (op, childrens.pop_only()), arena, + plan_arena, cache, transaction, ), - Operator::TopK(op) => >::into_executor( + Operator::TopK(op) => >::into_executor( (op, childrens.pop_only()), arena, + plan_arena, cache, transaction, ), - Operator::Values(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::ShowTable => as ExecutorNode<'a, T>>::into_executor( + Operator::Values(op) => >::into_executor( + Values::from(op), + arena, + plan_arena, + cache, + transaction, + ), + Operator::ShowTable => as ReadExecutor<'a, T>>::into_executor( ShowTables { metas: None }, arena, + plan_arena, cache, transaction, ), - Operator::ShowView => as ExecutorNode<'a, T>>::into_executor( + Operator::ShowView => as ReadExecutor<'a, T>>::into_executor( ShowViews { metas: None }, arena, + plan_arena, cache, transaction, ), - Operator::Explain => >::into_executor( - childrens.pop_only(), + Operator::Explain => >::into_executor( + Explain::from(childrens.pop_only()), arena, + plan_arena, cache, transaction, ), - Operator::Describe(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::Union(_) => >::into_executor( - childrens.pop_twins(), + Operator::Describe(op) => >::into_executor( + Describe::from(op), arena, + plan_arena, + cache, + transaction, + ), + Operator::Union(_) => >::into_executor( + Union::from(childrens.pop_twins()), + arena, + plan_arena, cache, transaction, ), Operator::SetMembership(op) => { let (left, right) = childrens.pop_twins(); - >::into_executor( - (op.kind, left, right), + >::into_executor( + SetMembership::from((op.kind, left, right)), arena, + plan_arena, cache, transaction, ) @@ -829,100 +849,168 @@ pub(crate) fn build_read<'a, T: Transaction + 'a>( } } -pub(crate) fn build_write<'a, T: Transaction + 'a>( +pub(crate) fn build_write<'a, T>( arena: &mut ExecArena<'a, T>, + plan_arena: &mut PlanArena<'a>, plan: LogicalPlan, - cache: ExecutionCaches<'a>, - transaction: *mut T, -) -> ExecId { + cache: ExecutionContext<'a>, + transaction: &'a mut T, +) -> ExecId +where + T: Transaction + 'a, +{ arena.init_context(cache, transaction); - + let transaction_ref: &T = transaction; let LogicalPlan { operator, childrens, physical_option, - _output_schema_ref, + .. } = plan; match operator { Operator::Insert(op) => { let input = childrens.pop_only(); - >::into_executor((op, input), arena, cache, transaction) + >::into_executor( + Insert::from((op, input)), + arena, + plan_arena, + cache, + transaction_ref, + ) } Operator::Update(op) => { let input = childrens.pop_only(); - >::into_executor((op, input), arena, cache, transaction) + >::into_executor( + Update::from((op, input)), + arena, + plan_arena, + cache, + transaction_ref, + ) } Operator::Delete(op) => { let input = childrens.pop_only(); - >::into_executor((op, input), arena, cache, transaction) - } - Operator::AddColumn(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::ChangeColumn(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::DropColumn(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::CreateTable(op) => { - >::into_executor(op, arena, cache, transaction) + >::into_executor( + Delete::from((op, input)), + arena, + plan_arena, + cache, + transaction_ref, + ) } + Operator::AddColumn(op) => >::into_executor( + AddColumn::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + Operator::ChangeColumn(op) => >::into_executor( + ChangeColumn::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + Operator::DropColumn(op) => >::into_executor( + DropColumn::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + Operator::CreateTable(op) => >::into_executor( + CreateTable::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), Operator::CreateIndex(op) => { let input = childrens.pop_only(); - >::into_executor( - (op, input), + >::into_executor( + CreateIndex::from((op, input)), arena, + plan_arena, cache, - transaction, + transaction_ref, ) } - Operator::CreateView(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::DropTable(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::DropView(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::DropIndex(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::Truncate(op) => { - >::into_executor(op, arena, cache, transaction) - } - Operator::CopyFromFile(op) => { - >::into_executor(op, arena, cache, transaction) - } + Operator::CreateView(op) => >::into_executor( + CreateView::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + Operator::DropTable(op) => >::into_executor( + DropTable::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + Operator::DropView(op) => >::into_executor( + DropView::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + Operator::DropIndex(op) => >::into_executor( + DropIndex::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + Operator::Truncate(op) => >::into_executor( + Truncate::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + #[cfg(feature = "copy")] + Operator::CopyFromFile(op) => >::into_executor( + CopyFromFile::from(op), + arena, + plan_arena, + cache, + transaction_ref, + ), + #[cfg(feature = "copy")] Operator::CopyToFile(op) => { let input = childrens.pop_only(); - >::into_executor( - (op, input), + >::into_executor( + CopyToFile::from((op, input)), arena, + plan_arena, cache, - transaction, + transaction_ref, ) } Operator::Analyze(op) => { let input = childrens.pop_only(); - >::into_executor((op, input), arena, cache, transaction) + >::into_executor( + Analyze::from((op, input)), + arena, + plan_arena, + cache, + transaction_ref, + ) } operator => { - let plan = LogicalPlan { - operator, - childrens, - physical_option, - _output_schema_ref, - }; - build_read(arena, plan, cache, transaction) + let mut plan = LogicalPlan::new(operator, *childrens); + plan.physical_option = physical_option; + build_read(arena, plan_arena, plan, cache, transaction_ref) } } } @@ -936,103 +1024,146 @@ mod test_utils { static EMPTY_TABLE_FUNCTIONS: std::sync::LazyLock = std::sync::LazyLock::new(TableFunctions::default); - impl<'a> IntoExecutionCaches<'a> for (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache) { - fn into_execution_caches(self) -> ExecutionCaches<'a> { - ( - self.0, - self.1, - self.2, - &EMPTY_SCALA_FUNCTIONS, - &EMPTY_TABLE_FUNCTIONS, - ) + pub(crate) fn empty_context<'a>( + table_cache: &'a TableCache, + view_cache: &'a ViewCache, + meta_cache: &'a StatisticsMetaCache, + ) -> ExecutionContext<'a> { + ExecutionContext::new( + table_cache, + view_cache, + meta_cache, + &EMPTY_SCALA_FUNCTIONS, + &EMPTY_TABLE_FUNCTIONS, + ) + } + + pub(crate) struct TestExecutor<'a, T: Transaction + 'a> { + executor: Executor<'a, T>, + plan_arena: PlanArena<'a>, + } + + impl<'a, T: Transaction + 'a> TestExecutor<'a, T> { + pub(crate) fn next_tuple(&mut self) -> Result, DatabaseError> { + self.executor.next_tuple(&mut self.plan_arena) } } - impl<'a> IntoExecutionCaches<'a> - for ( - &'a std::sync::Arc, - &'a std::sync::Arc, - &'a std::sync::Arc, - ) - { - fn into_execution_caches(self) -> ExecutionCaches<'a> { - ( - self.0.as_ref(), - self.1.as_ref(), - self.2.as_ref(), - &EMPTY_SCALA_FUNCTIONS, - &EMPTY_TABLE_FUNCTIONS, - ) + impl Iterator for TestExecutor<'_, T> { + type Item = Result; + + fn next(&mut self) -> Option { + match self.next_tuple() { + Ok(Some(tuple)) => Some(Ok(tuple.clone())), + Ok(None) => None, + Err(err) => Some(Err(err)), + } } } pub(crate) fn execute<'a, T, E>( executor: E, - cache: impl IntoExecutionCaches<'a>, - transaction: *mut T, - ) -> Executor<'a, T> + cache: ExecutionContext<'a>, + mut plan_arena: PlanArena<'a>, + transaction: &'a T, + ) -> TestExecutor<'a, T> where T: Transaction + 'a, - E: ReadExecutor<'a, T>, + E: ReadExecutor<'a, T, Input = E>, { - let cache = cache.into_execution_caches(); - let mut arena = ExecArena::default(); + let mut arena = ExecArena::new(); arena.init_context(cache, transaction); - let root = executor.into_executor(&mut arena, cache, transaction); - Executor::new(arena, root) + let root = >::into_executor( + executor, + &mut arena, + &mut plan_arena, + cache, + transaction, + ); + TestExecutor { + executor: Executor::new(arena, root), + plan_arena, + } } pub(crate) fn execute_mut<'a, T, E>( executor: E, - cache: impl IntoExecutionCaches<'a>, - transaction: *mut T, - ) -> Executor<'a, T> + cache: ExecutionContext<'a>, + mut plan_arena: PlanArena<'a>, + transaction: &'a T, + ) -> TestExecutor<'a, T> where T: Transaction + 'a, - E: WriteExecutor<'a, T>, + E: WriteExecutor<'a, T, Input = E>, { - let cache = cache.into_execution_caches(); - let mut arena = ExecArena::default(); + let mut arena = ExecArena::new(); arena.init_context(cache, transaction); - let root = executor.into_executor(&mut arena, cache, transaction); - Executor::new(arena, root) + let root = >::into_executor( + executor, + &mut arena, + &mut plan_arena, + cache, + transaction, + ); + TestExecutor { + executor: Executor::new(arena, root), + plan_arena, + } } pub(crate) fn execute_input<'a, T, E>( input: E::Input, - cache: impl IntoExecutionCaches<'a>, - transaction: *mut T, - ) -> Executor<'a, T> + cache: ExecutionContext<'a>, + mut plan_arena: PlanArena<'a>, + transaction: &'a T, + ) -> TestExecutor<'a, T> where T: Transaction + 'a, - E: ExecutorNode<'a, T>, + E: ReadExecutor<'a, T>, { - let cache = cache.into_execution_caches(); - let mut arena = ExecArena::default(); + let mut arena = ExecArena::new(); arena.init_context(cache, transaction); - let root = E::into_executor(input, &mut arena, cache, transaction); - Executor::new(arena, root) + let root = >::into_executor( + input, + &mut arena, + &mut plan_arena, + cache, + transaction, + ); + TestExecutor { + executor: Executor::new(arena, root), + plan_arena, + } } #[allow(dead_code)] pub(crate) fn execute_input_mut<'a, T, E>( input: E::Input, - cache: impl IntoExecutionCaches<'a>, - transaction: *mut T, - ) -> Executor<'a, T> + cache: ExecutionContext<'a>, + mut plan_arena: PlanArena<'a>, + transaction: &'a T, + ) -> TestExecutor<'a, T> where T: Transaction + 'a, - E: ExecutorNode<'a, T>, + E: WriteExecutor<'a, T>, { - let cache = cache.into_execution_caches(); - let mut arena = ExecArena::default(); + let mut arena = ExecArena::new(); arena.init_context(cache, transaction); - let root = E::into_executor(input, &mut arena, cache, transaction); - Executor::new(arena, root) + let root = >::into_executor( + input, + &mut arena, + &mut plan_arena, + cache, + transaction, + ); + TestExecutor { + executor: Executor::new(arena, root), + plan_arena, + } } pub fn try_collect( - executor: Executor<'_, T>, + executor: TestExecutor<'_, T>, ) -> Result, DatabaseError> { let mut executor = executor; let mut tuples = Vec::new(); @@ -1046,4 +1177,6 @@ mod test_utils { #[cfg(all(test, not(target_arch = "wasm32")))] #[allow(unused_imports)] -pub(crate) use test_utils::{execute, execute_input, execute_input_mut, execute_mut, try_collect}; +pub(crate) use test_utils::{ + empty_context, execute, execute_input, execute_input_mut, execute_mut, try_collect, +}; diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 045dfa96..e2e04548 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -58,7 +58,7 @@ impl ScalarExpression { } => { let value = expr.eval(tuple)?; if let Some(evaluator) = evaluator { - evaluator.eval_cast(&value) + evaluator.eval(&value) } else { Ok(value) } diff --git a/src/expression/function/mod.rs b/src/expression/function/mod.rs index f8f71a77..beb07df4 100644 --- a/src/expression/function/mod.rs +++ b/src/expression/function/mod.rs @@ -14,13 +14,12 @@ use crate::types::LogicalType; use kite_sql_serde_macros::ReferenceSerialization; -use serde::{Deserialize, Serialize}; use std::sync::Arc; pub mod scala; pub mod table; -#[derive(Debug, Eq, PartialEq, Hash, Clone, Serialize, Deserialize, ReferenceSerialization)] +#[derive(Debug, Eq, PartialEq, Hash, Clone, ReferenceSerialization)] pub struct FunctionSummary { pub name: Arc, pub arg_types: Vec, diff --git a/src/expression/function/table.rs b/src/expression/function/table.rs index 308ee308..3c77c831 100644 --- a/src/expression/function/table.rs +++ b/src/expression/function/table.rs @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::TableCatalog; +use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; -use crate::types::tuple::{SchemaRef, Tuple}; +use crate::planner::TableArena; +use crate::types::tuple::{Schema, Tuple}; use kite_sql_serde_macros::ReferenceSerialization; use std::fmt::Debug; use std::hash::{Hash, Hasher}; @@ -37,6 +38,12 @@ impl Deref for ArcTableFunctionImpl { #[derive(Debug, Clone, ReferenceSerialization)] pub struct TableFunction { pub(crate) args: Vec, + pub(crate) catalog: TableFunctionCatalog, +} + +#[derive(Debug, Clone, ReferenceSerialization)] +pub struct TableFunctionCatalog { + pub(crate) schema: Schema, pub(crate) inner: ArcTableFunctionImpl, } @@ -61,21 +68,25 @@ pub trait TableFunctionImpl: Debug + Send + Sync { fn summary(&self) -> &FunctionSummary; - fn output_schema(&self) -> &SchemaRef; - - fn table(&self) -> &TableCatalog; + fn output_schema_into( + &self, + table_name: &TableName, + table_arena: &mut TableArena, + schema: &mut Schema, + ); } impl TableFunction { pub fn summary(&self) -> &FunctionSummary { - self.inner.summary() + self.catalog.inner.summary() } - pub fn output_schema(&self) -> &SchemaRef { - self.inner.output_schema() + pub fn output_schema_into(&self, schema: &mut Schema) { + schema.clear(); + schema.extend(self.catalog.schema.iter().copied()); } - pub fn table(&self) -> &TableCatalog { - self.inner.table() + pub fn inner(&self) -> &ArcTableFunctionImpl { + &self.catalog.inner } } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 67ecf183..accd0040 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -19,18 +19,22 @@ use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; use crate::expression::visitor::{walk_expr, Visitor}; use crate::expression::visitor_mut::VisitorMut; +use crate::planner::operator::sort::SortField; +use crate::planner::{MetaArena, PlanArena}; use crate::types::evaluator::{ - binary_create, cast_create, unary_create, BinaryEvaluatorBox, CastEvaluatorBox, - UnaryEvaluatorBox, + binary_create, cast_create, unary_create, BinaryEvaluatorRef, CastEvaluatorRef, + UnaryEvaluatorRef, }; use crate::types::value::DataValue; use crate::types::{CharLengthUnits, LogicalType}; use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; -use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; +#[cfg(feature = "decimal")] +use rust_decimal::Decimal; use std::borrow::Cow; use std::fmt::{Debug, Formatter}; use std::hash::Hash; +use std::sync::Arc; use std::{fmt, mem}; pub mod agg; @@ -41,23 +45,13 @@ pub mod simplify; pub mod visitor; pub mod visitor_mut; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TrimWhereField { Both, Leading, Trailing, } -impl From for TrimWhereField { - fn from(value: sqlparser::ast::TrimWhereField) -> Self { - match value { - sqlparser::ast::TrimWhereField::Both => Self::Both, - sqlparser::ast::TrimWhereField::Leading => Self::Leading, - sqlparser::ast::TrimWhereField::Trailing => Self::Trailing, - } - } -} - #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub enum AliasType { Name(String), @@ -82,7 +76,7 @@ pub enum ScalarExpression { TypeCast { expr: Box, ty: LogicalType, - evaluator: Option, + evaluator: Option, }, IsNull { negated: bool, @@ -91,14 +85,14 @@ pub enum ScalarExpression { Unary { op: UnaryOperator, expr: Box, - evaluator: Option, + evaluator: Option, ty: LogicalType, }, Binary { op: BinaryOperator, left_expr: Box, right_expr: Box, - evaluator: Option, + evaluator: Option, ty: LogicalType, }, AggCall { @@ -165,17 +159,133 @@ pub enum ScalarExpression { }, } -pub struct BindEvaluator; +impl From for ScalarExpression { + fn from(value: DataValue) -> Self { + ScalarExpression::Constant(value) + } +} + +macro_rules! impl_scalar_expression_from_data_value { + ($($ty:ty),+ $(,)?) => { + $( + impl From<$ty> for ScalarExpression { + fn from(value: $ty) -> Self { + ScalarExpression::Constant(DataValue::from(value)) + } + } + )+ + }; +} + +impl_scalar_expression_from_data_value!( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + String, + Option, + Option, + Option, + Option, + Option, + Option, + Option, + Option, + Option, + Option, + Option, + Option, +); +#[cfg(feature = "decimal")] +impl_scalar_expression_from_data_value!(Decimal, Option); + +impl From<&str> for ScalarExpression { + fn from(value: &str) -> Self { + ScalarExpression::Constant(DataValue::from(value.to_string())) + } +} + +impl From> for ScalarExpression { + fn from(value: Option<&str>) -> Self { + ScalarExpression::Constant(value.map(str::to_string).into()) + } +} + +impl From> for ScalarExpression { + fn from(value: Arc) -> Self { + ScalarExpression::Constant(DataValue::from(value.to_string())) + } +} + +impl From>> for ScalarExpression { + fn from(value: Option>) -> Self { + ScalarExpression::Constant(value.map(|value| value.to_string()).into()) + } +} + +#[cfg(feature = "time")] +mod chrono_scalar_expression { + use super::ScalarExpression; + use crate::types::value::DataValue; + use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; + + impl From for ScalarExpression { + fn from(value: NaiveDate) -> Self { + ScalarExpression::Constant(DataValue::from(&value)) + } + } + + impl From> for ScalarExpression { + fn from(value: Option) -> Self { + ScalarExpression::Constant(DataValue::from(value.as_ref())) + } + } -impl VisitorMut<'_> for BindEvaluator { + impl From for ScalarExpression { + fn from(value: NaiveDateTime) -> Self { + ScalarExpression::Constant(DataValue::from(&value)) + } + } + + impl From> for ScalarExpression { + fn from(value: Option) -> Self { + ScalarExpression::Constant(DataValue::from(value.as_ref())) + } + } + + impl From for ScalarExpression { + fn from(value: NaiveTime) -> Self { + ScalarExpression::Constant(DataValue::from(&value)) + } + } + + impl From> for ScalarExpression { + fn from(value: Option) -> Self { + ScalarExpression::Constant(DataValue::from(value.as_ref())) + } + } +} + +pub struct BindEvaluator<'a, 'p> { + pub(crate) arena: &'a PlanArena<'p>, +} + +impl VisitorMut<'_> for BindEvaluator<'_, '_> { fn visit_type_cast( &mut self, expr: &'_ mut ScalarExpression, ty: &'_ mut LogicalType, - evaluator: &'_ mut Option, + evaluator: &'_ mut Option, ) -> Result<(), DatabaseError> { self.visit(expr)?; - let from = expr.return_type(); + let from = expr.return_type(self.arena); *evaluator = if from.as_ref() == ty { None } else { @@ -189,12 +299,12 @@ impl VisitorMut<'_> for BindEvaluator { &mut self, op: &'_ mut UnaryOperator, expr: &'_ mut ScalarExpression, - evaluator: &'_ mut Option, + evaluator: &'_ mut Option, _ty: &'_ mut LogicalType, ) -> Result<(), DatabaseError> { self.visit(expr)?; - let ty = expr.return_type(); + let ty = expr.return_type(self.arena); if ty.is_unsigned_numeric() { let target_ty = match ty.as_ref() { LogicalType::UTinyint => LogicalType::Tinyint, @@ -206,9 +316,10 @@ impl VisitorMut<'_> for BindEvaluator { *expr = ScalarExpression::type_cast( mem::replace(expr, ScalarExpression::Empty), Cow::Owned(target_ty), + self.arena, )?; } - *evaluator = Some(unary_create(expr.return_type(), *op)?); + *evaluator = Some(unary_create(expr.return_type(self.arena), *op)?); Ok(()) } @@ -218,20 +329,21 @@ impl VisitorMut<'_> for BindEvaluator { op: &'_ mut BinaryOperator, left_expr: &'_ mut ScalarExpression, right_expr: &'_ mut ScalarExpression, - evaluator: &'_ mut Option, + evaluator: &'_ mut Option, _ty: &'_ mut LogicalType, ) -> Result<(), DatabaseError> { self.visit(left_expr)?; self.visit(right_expr)?; - let left_ty = left_expr.return_type().into_owned(); - let right_ty = right_expr.return_type().into_owned(); + let left_ty = left_expr.return_type(self.arena).into_owned(); + let right_ty = right_expr.return_type(self.arena).into_owned(); let ty = LogicalType::max_logical_type(&left_ty, &right_ty)?; let fn_cast = |expr: &mut ScalarExpression, ty: &LogicalType| -> Result<(), DatabaseError> { *expr = ScalarExpression::type_cast( mem::replace(expr, ScalarExpression::Empty), Cow::Borrowed(ty), + self.arena, )?; Ok(()) }; @@ -274,6 +386,22 @@ impl Visitor<'_> for HasCountStar { } impl ScalarExpression { + pub fn asc(self) -> SortField { + SortField::from(self).asc() + } + + pub fn desc(self) -> SortField { + SortField::from(self).desc() + } + + pub fn nulls_first(self) -> SortField { + SortField::from(self).nulls_first() + } + + pub fn nulls_last(self) -> SortField { + SortField::from(self).nulls_last() + } + pub fn column_expr(column: ColumnRef, position: usize) -> ScalarExpression { ScalarExpression::ColumnRef { column, position } } @@ -281,8 +409,9 @@ impl ScalarExpression { pub fn type_cast( expr: ScalarExpression, ty: Cow<'_, LogicalType>, + arena: &PlanArena, ) -> Result { - let from = expr.return_type(); + let from = expr.return_type(arena); if from.as_ref() == ty.as_ref() { return Ok(expr); } @@ -295,7 +424,7 @@ impl ScalarExpression { }) } - pub(crate) fn eq_ignore_colref_pos(&self, other: &ScalarExpression) -> bool { + pub(crate) fn eq_ignore_colref_pos(&self, other: &ScalarExpression, arena: &PlanArena) -> bool { match (self.unpack_alias_ref(), other.unpack_alias_ref()) { ( ScalarExpression::ColumnRef { @@ -304,7 +433,7 @@ impl ScalarExpression { ScalarExpression::ColumnRef { column: rhs_column, .. }, - ) => lhs_column.same_column(rhs_column), + ) => arena.same_column(*lhs_column, *rhs_column), (lhs, rhs) => lhs == rhs, } } @@ -337,10 +466,12 @@ impl ScalarExpression { } } - pub fn return_type(&self) -> Cow<'_, LogicalType> { + pub fn return_type<'a>(&'a self, arena: &'a PlanArena<'_>) -> Cow<'a, LogicalType> { match self { ScalarExpression::Constant(v) => Cow::Owned(v.logical_type()), - ScalarExpression::ColumnRef { column, .. } => Cow::Borrowed(column.datatype()), + ScalarExpression::ColumnRef { column, .. } => { + Cow::Borrowed(arena.column(*column).datatype()) + } ScalarExpression::Binary { ty: return_type, .. } @@ -378,12 +509,12 @@ impl ScalarExpression { ScalarExpression::Trim { .. } => { Cow::Owned(LogicalType::Varchar(None, CharLengthUnits::Characters)) } - ScalarExpression::Alias { expr, .. } => expr.return_type(), + ScalarExpression::Alias { expr, .. } => expr.return_type(arena), ScalarExpression::Empty | ScalarExpression::TableFunction(_) => unreachable!(), ScalarExpression::Tuple(exprs) => { let types = exprs .iter() - .map(|expr| expr.return_type().into_owned()) + .map(|expr| expr.return_type(arena).into_owned()) .collect_vec(); Cow::Owned(LogicalType::Tuple(types)) @@ -394,16 +525,22 @@ impl ScalarExpression { } } - pub fn visit_referenced_columns( + pub fn visit_referenced_columns( &self, - only_column_ref: bool, - f: &mut impl FnMut(&ColumnRef) -> bool, + arena: &mut A, + f: &mut impl FnMut(&mut A, &ColumnRef) -> bool, ) -> bool { - struct ColumnRefVisitor<'a, F> { + struct ColumnRefVisitor<'a, A, F> { f: &'a mut F, keep_going: bool, + arena: &'a mut A, } - impl bool> Visitor<'_> for ColumnRefVisitor<'_, F> { + + impl Visitor<'_> for ColumnRefVisitor<'_, A, F> + where + A: MetaArena, + F: FnMut(&mut A, &ColumnRef) -> bool, + { fn visit(&mut self, expr: &ScalarExpression) -> Result<(), DatabaseError> { if self.keep_going { walk_expr(self, expr)?; @@ -412,85 +549,106 @@ impl ScalarExpression { } fn visit_column_ref(&mut self, col: &ColumnRef) -> Result<(), DatabaseError> { - self.keep_going = (self.f)(col); + self.keep_going = (self.f)(self.arena, col); Ok(()) } } - struct OutputColumnVisitor<'a, F> { + + let mut visitor = ColumnRefVisitor { + f, + keep_going: true, + arena, + }; + visitor.visit(self).unwrap(); + visitor.keep_going + } + + pub fn any_referenced_column( + &self, + arena: &PlanArena, + mut predicate: impl FnMut(&PlanArena, &ColumnRef) -> bool, + ) -> bool { + struct ColumnRefVisitor<'a, 'p, F> { f: &'a mut F, - keep_going: bool, + any: bool, + arena: &'a PlanArena<'p>, } - impl bool> Visitor<'_> for OutputColumnVisitor<'_, F> { - fn visit(&mut self, expr: &ScalarExpression) -> Result<(), DatabaseError> { - if !self.keep_going { - return Ok(()); - } - let output = expr.output_column(); - self.keep_going = (self.f)(&output); - if self.keep_going { + impl bool> Visitor<'_> for ColumnRefVisitor<'_, '_, F> { + fn visit(&mut self, expr: &ScalarExpression) -> Result<(), DatabaseError> { + if !self.any { walk_expr(self, expr)?; } Ok(()) } - } - if only_column_ref { - let mut visitor = ColumnRefVisitor { - f, - keep_going: true, - }; - visitor.visit(self).unwrap(); - visitor.keep_going - } else { - let mut visitor = OutputColumnVisitor { - f, - keep_going: true, - }; - visitor.visit(self).unwrap(); - visitor.keep_going + fn visit_column_ref(&mut self, col: &ColumnRef) -> Result<(), DatabaseError> { + self.any = (self.f)(self.arena, col); + Ok(()) + } } - } - pub fn any_referenced_column( - &self, - only_column_ref: bool, - mut predicate: impl FnMut(&ColumnRef) -> bool, - ) -> bool { - let mut found = false; - self.visit_referenced_columns(only_column_ref, &mut |column| { - found = predicate(column); - !found - }); - found + let mut visitor = ColumnRefVisitor { + f: &mut predicate, + any: false, + arena, + }; + visitor.visit(self).unwrap(); + visitor.any } pub fn all_referenced_columns( &self, - only_column_ref: bool, - mut predicate: impl FnMut(&ColumnRef) -> bool, + arena: &PlanArena, + mut predicate: impl FnMut(&PlanArena, &ColumnRef) -> bool, ) -> bool { - let mut all = true; - self.visit_referenced_columns(only_column_ref, &mut |column| { - all = predicate(column); - all - }); - all + struct ColumnRefVisitor<'a, 'p, F> { + f: &'a mut F, + all: bool, + arena: &'a PlanArena<'p>, + } + + impl bool> Visitor<'_> for ColumnRefVisitor<'_, '_, F> { + fn visit(&mut self, expr: &ScalarExpression) -> Result<(), DatabaseError> { + if self.all { + walk_expr(self, expr)?; + } + Ok(()) + } + + fn visit_column_ref(&mut self, col: &ColumnRef) -> Result<(), DatabaseError> { + self.all = (self.f)(self.arena, col); + Ok(()) + } + } + + let mut visitor = ColumnRefVisitor { + f: &mut predicate, + all: true, + arena, + }; + visitor.visit(self).unwrap(); + visitor.all } - pub fn has_table_ref_column(&self) -> bool { - struct TableRefChecker { + pub fn has_table_ref_column(&self, arena: &PlanArena) -> bool { + struct TableRefChecker<'arena, 'table> { found: bool, + arena: &'arena PlanArena<'table>, } - impl Visitor<'_> for TableRefChecker { + impl Visitor<'_> for TableRefChecker<'_, '_> { fn visit_column_ref(&mut self, col: &ColumnRef) -> Result<(), DatabaseError> { + let col = self.arena.column(*col); if col.table_name().is_some() && col.id().is_some() { self.found = true; } Ok(()) } } - let mut checker = TableRefChecker { found: false }; + let mut checker = TableRefChecker { + found: false, + arena, + }; checker.visit(self).unwrap(); checker.found } @@ -525,25 +683,31 @@ impl ScalarExpression { checker.has_agg } - pub fn output_name(&self) -> String { + fn output_name_by(&self, fn_display: &impl Fn(ColumnRef) -> N) -> String { match self { ScalarExpression::Constant(value) => format!("{value}"), - ScalarExpression::ColumnRef { column, .. } => column.full_name(), + ScalarExpression::ColumnRef { column, .. } => format!("{}", fn_display(*column)), ScalarExpression::Alias { alias, expr } => match alias { AliasType::Name(alias) => alias.to_string(), AliasType::Expr(alias_expr) => { - format!("({}) as ({})", expr, alias_expr.output_name()) + format!( + "({}) as ({})", + expr.output_name_by(fn_display), + alias_expr.output_name_by(fn_display) + ) } }, ScalarExpression::TypeCast { expr, ty, .. } => { - format!("cast ({} as {})", expr.output_name(), ty) + format!("cast ({} as {})", expr.output_name_by(fn_display), ty) } ScalarExpression::IsNull { expr, negated } => { let suffix = if *negated { "is not null" } else { "is null" }; - format!("{} {}", expr.output_name(), suffix) + format!("{} {}", expr.output_name_by(fn_display), suffix) + } + ScalarExpression::Unary { expr, op, .. } => { + format!("{}{}", op, expr.output_name_by(fn_display)) } - ScalarExpression::Unary { expr, op, .. } => format!("{}{}", op, expr.output_name()), ScalarExpression::Binary { left_expr, right_expr, @@ -551,9 +715,9 @@ impl ScalarExpression { .. } => format!( "({} {} {})", - left_expr.output_name(), + left_expr.output_name_by(fn_display), op, - right_expr.output_name(), + right_expr.output_name_by(fn_display), ), ScalarExpression::AggCall { args, @@ -561,7 +725,10 @@ impl ScalarExpression { distinct, .. } => { - let args_str = args.iter().map(|expr| expr.output_name()).join(", "); + let args_str = args + .iter() + .map(|expr| expr.output_name_by(fn_display)) + .join(", "); let op = |allow_distinct, distinct| { if allow_distinct && distinct { "distinct " @@ -581,9 +748,17 @@ impl ScalarExpression { negated, expr, } => { - let args_string = args.iter().map(|arg| arg.output_name()).join(", "); + let args_string = args + .iter() + .map(|arg| arg.output_name_by(fn_display)) + .join(", "); let op_string = if *negated { "not in" } else { "in" }; - format!("{} {} ({})", expr.output_name(), op_string, args_string) + format!( + "{} {} ({})", + expr.output_name_by(fn_display), + op_string, + args_string + ) } ScalarExpression::Between { expr, @@ -594,10 +769,10 @@ impl ScalarExpression { let op_string = if *negated { "not between" } else { "between" }; format!( "{} {} [{}, {}]", - expr.output_name(), + expr.output_name_by(fn_display), op_string, - left_expr.output_name(), - right_expr.output_name() + left_expr.output_name_by(fn_display), + right_expr.output_name_by(fn_display) ) } ScalarExpression::SubString { @@ -608,13 +783,13 @@ impl ScalarExpression { let op = |tag: &str, num_expr: &Option>| { num_expr .as_ref() - .map(|expr| format!(", {}: {}", tag, expr.output_name())) + .map(|expr| format!(", {}: {}", tag, expr.output_name_by(fn_display))) .unwrap_or_default() }; format!( "substring({}{}{})", - expr.output_name(), + expr.output_name_by(fn_display), op("from", from_expr), op("for", for_expr), ) @@ -622,8 +797,8 @@ impl ScalarExpression { ScalarExpression::Position { expr, in_expr } => { format!( "position({} in {})", - expr.output_name(), - in_expr.output_name() + expr.output_name_by(fn_display), + in_expr.output_name_by(fn_display) ) } ScalarExpression::Trim { @@ -634,8 +809,8 @@ impl ScalarExpression { let trim_what_str = { trim_what_expr .as_ref() - .map(|expr| expr.output_name()) - .unwrap_or(" ".to_string()) + .map(|expr| expr.output_name_by(fn_display)) + .unwrap_or_else(|| " ".to_string()) }; let trim_where_str = match trim_where { Some(TrimWhereField::Both) => format!("both '{trim_what_str}' from"), @@ -649,20 +824,33 @@ impl ScalarExpression { } } }; - format!("trim({} {})", trim_where_str, expr.output_name()) + format!( + "trim({} {})", + trim_where_str, + expr.output_name_by(fn_display) + ) } ScalarExpression::Empty => unreachable!(), ScalarExpression::Tuple(args) => { - let args_str = args.iter().map(|expr| expr.output_name()).join(", "); + let args_str = args + .iter() + .map(|expr| expr.output_name_by(fn_display)) + .join(", "); format!("({args_str})") } ScalarExpression::ScalaFunction(ScalarFunction { args, inner }) => { - let args_str = args.iter().map(|expr| expr.output_name()).join(", "); + let args_str = args + .iter() + .map(|expr| expr.output_name_by(fn_display)) + .join(", "); format!("{}({})", inner.summary().name, args_str) } - ScalarExpression::TableFunction(TableFunction { args, inner }) => { - let args_str = args.iter().map(|expr| expr.output_name()).join(", "); - format!("{}({})", inner.summary().name, args_str) + ScalarExpression::TableFunction(TableFunction { args, catalog }) => { + let args_str = args + .iter() + .map(|expr| expr.output_name_by(fn_display)) + .join(", "); + format!("{}({})", catalog.inner.summary().name, args_str) } ScalarExpression::If { condition, @@ -670,24 +858,40 @@ impl ScalarExpression { right_expr, .. } => { - format!("if {condition} ({left_expr}, {right_expr})") + format!( + "if {} ({}, {})", + condition.output_name_by(fn_display), + left_expr.output_name_by(fn_display), + right_expr.output_name_by(fn_display) + ) } ScalarExpression::IfNull { left_expr, right_expr, .. } => { - format!("ifnull({left_expr}, {right_expr})") + format!( + "ifnull({}, {})", + left_expr.output_name_by(fn_display), + right_expr.output_name_by(fn_display) + ) } ScalarExpression::NullIf { left_expr, right_expr, .. } => { - format!("ifnull({left_expr}, {right_expr})") + format!( + "ifnull({}, {})", + left_expr.output_name_by(fn_display), + right_expr.output_name_by(fn_display) + ) } ScalarExpression::Coalesce { exprs, .. } => { - let exprs_str = exprs.iter().map(|expr| expr.output_name()).join(", "); + let exprs_str = exprs + .iter() + .map(|expr| expr.output_name_by(fn_display)) + .join(", "); format!("coalesce({exprs_str})") } ScalarExpression::CaseWhen { @@ -698,12 +902,18 @@ impl ScalarExpression { } => { let op = |tag: &str, expr: &Option>| { expr.as_ref() - .map(|expr| format!("{}{} ", tag, expr.output_name())) + .map(|expr| format!("{}{} ", tag, expr.output_name_by(fn_display))) .unwrap_or_default() }; let expr_pairs_str = expr_pairs .iter() - .map(|(when_expr, then_expr)| format!("when {when_expr} then {then_expr}")) + .map(|(when_expr, then_expr)| { + format!( + "when {} then {}", + when_expr.output_name_by(fn_display), + then_expr.output_name_by(fn_display) + ) + }) .join(" "); format!( @@ -716,19 +926,28 @@ impl ScalarExpression { } } - pub fn output_column(&self) -> ColumnRef { + pub fn output_name(&self, arena: &PlanArena) -> String { + self.output_name_by(&|column| arena.column(column).full_name()) + } + + pub fn output_column_ref(&self, arena: &mut PlanArena) -> ColumnRef { match self { - ScalarExpression::ColumnRef { column, .. } => column.clone(), + ScalarExpression::ColumnRef { column, .. } => *column, ScalarExpression::Alias { alias: AliasType::Expr(expr), .. - } => expr.output_column(), - _ => ColumnRef::from(ColumnCatalog::new( - self.output_name(), - true, - // SAFETY: default expr must not be [`ScalarExpression::ColumnRef`] - ColumnDesc::new(self.return_type().into_owned(), None, false, None).unwrap(), - )), + } => expr.output_column_ref(arena), + _ => { + let output_name = self.output_name(arena); + let return_type = self.return_type(arena).into_owned(); + let column = ColumnCatalog::new( + output_name, + true, + // SAFETY: default expr must not be [`ScalarExpression::ColumnRef`] + ColumnDesc::new(return_type, None, false, None).unwrap(), + ); + arena.alloc_column(column) + } } } } @@ -740,19 +959,6 @@ pub enum UnaryOperator { Not, } -impl TryFrom for UnaryOperator { - type Error = DatabaseError; - - fn try_from(value: SqlUnaryOperator) -> Result { - match value { - SqlUnaryOperator::Plus => Ok(UnaryOperator::Plus), - SqlUnaryOperator::Minus => Ok(UnaryOperator::Minus), - SqlUnaryOperator::Not => Ok(UnaryOperator::Not), - op => Err(DatabaseError::UnsupportedStmt(format!("{op}"))), - } - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ReferenceSerialization)] pub enum BinaryOperator { Plus, @@ -779,7 +985,7 @@ pub enum BinaryOperator { impl fmt::Display for ScalarExpression { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{}", self.output_name()) + write!(f, "{}", self.output_name_by(&|column| column)) } } @@ -830,34 +1036,9 @@ impl fmt::Display for UnaryOperator { } } -impl TryFrom for BinaryOperator { - type Error = DatabaseError; - - fn try_from(value: SqlBinaryOperator) -> Result { - match value { - SqlBinaryOperator::Plus => Ok(BinaryOperator::Plus), - SqlBinaryOperator::Minus => Ok(BinaryOperator::Minus), - SqlBinaryOperator::Multiply => Ok(BinaryOperator::Multiply), - SqlBinaryOperator::Divide => Ok(BinaryOperator::Divide), - SqlBinaryOperator::Modulo => Ok(BinaryOperator::Modulo), - SqlBinaryOperator::StringConcat => Ok(BinaryOperator::StringConcat), - SqlBinaryOperator::Gt => Ok(BinaryOperator::Gt), - SqlBinaryOperator::Lt => Ok(BinaryOperator::Lt), - SqlBinaryOperator::GtEq => Ok(BinaryOperator::GtEq), - SqlBinaryOperator::LtEq => Ok(BinaryOperator::LtEq), - SqlBinaryOperator::Spaceship => Ok(BinaryOperator::Spaceship), - SqlBinaryOperator::Eq => Ok(BinaryOperator::Eq), - SqlBinaryOperator::NotEq => Ok(BinaryOperator::NotEq), - SqlBinaryOperator::And => Ok(BinaryOperator::And), - SqlBinaryOperator::Or => Ok(BinaryOperator::Or), - op => Err(DatabaseError::UnsupportedStmt(format!("{op}"))), - } - } -} - #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::db::test::build_table; use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; @@ -866,30 +1047,31 @@ mod test { ArcScalarFunctionImpl, ScalarFunction, ScalarFunctionImpl, }; use crate::expression::function::table::{ - ArcTableFunctionImpl, TableFunction, TableFunctionImpl, + ArcTableFunctionImpl, TableFunction, TableFunctionCatalog, TableFunctionImpl, }; use crate::expression::TrimWhereField; use crate::expression::{AliasType, BinaryOperator, ScalarExpression, UnaryOperator}; use crate::function::current_date::CurrentDate; use crate::function::numbers::Numbers; + use crate::planner::{PlanArena, TableArenaCell}; use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; - use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; + use crate::storage::rocksdb::RocksStorage; + use crate::storage::rocksdb::RocksTransaction; use crate::storage::{Storage, Transaction}; use crate::types::evaluator::{binary_create, cast_create, unary_create}; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; use std::borrow::Cow; - use std::hash::RandomState; use std::io::{Cursor, Seek, SeekFrom}; - use std::sync::Arc; use tempfile::TempDir; #[test] fn test_eq_ignore_colref_pos() -> Result<(), DatabaseError> { + let table_arena = TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); let left = ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::new( + arena.alloc_column(ColumnCatalog::new( "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, None, false, None)?, @@ -897,7 +1079,7 @@ mod test { 0, ); let right = ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::new( + arena.alloc_column(ColumnCatalog::new( "c1".to_string(), true, ColumnDesc::new(LogicalType::Bigint, None, false, None)?, @@ -905,7 +1087,7 @@ mod test { 2, ); let different = ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::new( + arena.alloc_column(ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::Integer, None, false, None)?, @@ -913,8 +1095,8 @@ mod test { 0, ); - assert!(left.eq_ignore_colref_pos(&right)); - assert!(!left.eq_ignore_colref_pos(&different)); + assert!(left.eq_ignore_colref_pos(&right, &arena)); + assert!(!left.eq_ignore_colref_pos(&different, &arena)); Ok(()) } @@ -925,13 +1107,15 @@ mod test { expr: ScalarExpression, drive: Option<&ReferenceDecodeContext<'_, RocksTransaction>>, reference_tables: &mut ReferenceTables, + arena: &mut PlanArena, ) -> Result<(), DatabaseError> { - expr.encode(cursor, false, reference_tables)?; + expr.encode(cursor, false, reference_tables, arena)?; cursor.seek(SeekFrom::Start(0))?; - assert_eq!( - ScalarExpression::decode(cursor, drive, reference_tables)?, - expr + let decoded = ScalarExpression::decode(cursor, drive, reference_tables, arena)?; + assert!( + decoded.eq_ignore_colref_pos(&expr, arena), + "decoded expression does not match: decoded={decoded:?}, expected={expr:?}", ); cursor.seek(SeekFrom::Start(0))?; @@ -941,22 +1125,37 @@ mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let mut table_cache = crate::storage::TableCache::default(); + let table_arena = TableArenaCell::default(); let mut scala_functions = ScalaFunctions::default(); let current_date = CurrentDate::new(); scala_functions.insert(current_date.summary().clone(), current_date); let mut table_functions = TableFunctions::default(); let numbers = Numbers::new(); - table_functions.insert(numbers.summary().clone(), numbers); - build_table(&table_cache, &mut transaction)?; + let mut schema = Vec::new(); + numbers.output_schema_into( + &numbers.summary().name, + table_arena.borrow_mut(), + &mut schema, + ); + table_functions.insert( + numbers.summary().clone(), + TableFunctionCatalog { + schema, + inner: ArcTableFunctionImpl(numbers), + }, + ); + let mut plan_arena = PlanArena::new(&table_arena); + build_table(&mut table_cache, &mut transaction, &mut plan_arena)?; + let mut plan_arena = PlanArena::new(&table_arena); let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); - let c3_column_id = { + let c3_column = { let table = transaction .table(&table_cache, "t1".to_string().into())? .unwrap(); - *table.get_column_id_by_name("c3").unwrap() + table.get_column_by_name("c3").unwrap() }; let context = ReferenceDecodeContext::with_functions( Some((&transaction, &table_cache)), @@ -969,12 +1168,14 @@ mod test { ScalarExpression::Constant(DataValue::Null), Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, ScalarExpression::Constant(DataValue::Int32(42)), Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -985,44 +1186,28 @@ mod test { }), Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, - ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c3".to_string(), - relation: ColumnRelation::Table { - column_id: c3_column_id, - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new(LogicalType::Integer, None, false, None)?, - false, - )), - 0, - ), + ScalarExpression::column_expr(c3_column, 0), Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c4".to_string(), - relation: ColumnRelation::None, - }, + plan_arena.alloc_column(ColumnCatalog::new( + "c4".to_string(), false, ColumnDesc::new(LogicalType::Boolean, None, false, None)?, - false, )), 1, ), Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1032,6 +1217,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1041,6 +1227,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1054,6 +1241,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1063,6 +1251,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1077,6 +1266,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1088,6 +1278,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1102,6 +1293,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1114,6 +1306,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1125,6 +1318,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1135,6 +1329,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1146,6 +1341,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1156,6 +1352,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1166,6 +1363,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1176,6 +1374,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1185,6 +1384,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1195,6 +1395,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1205,6 +1406,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1215,18 +1417,21 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, ScalarExpression::Empty, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, ScalarExpression::Tuple(vec![ScalarExpression::Empty]), Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1236,15 +1441,20 @@ mod test { }), Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, ScalarExpression::TableFunction(TableFunction { args: vec![ScalarExpression::Empty], - inner: ArcTableFunctionImpl(Numbers::new()), + catalog: TableFunctionCatalog { + schema: Vec::new(), + inner: ArcTableFunctionImpl(Numbers::new()), + }, }), Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1256,6 +1466,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1266,6 +1477,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1276,6 +1488,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1285,6 +1498,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1296,6 +1510,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1307,6 +1522,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; fn_assert( &mut cursor, @@ -1318,6 +1534,7 @@ mod test { }, Some(&context), &mut reference_tables, + &mut plan_arena, )?; Ok(()) diff --git a/src/expression/range_detacher.rs b/src/expression/range_detacher.rs index c43f5774..9a60a0ad 100644 --- a/src/expression/range_detacher.rs +++ b/src/expression/range_detacher.rs @@ -15,6 +15,7 @@ use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::{BinaryOperator, ScalarExpression}; +use crate::planner::PlanArena; use crate::types::value::DataValue; use crate::types::ColumnId; use itertools::Itertools; @@ -179,16 +180,22 @@ impl Range { } } -pub struct RangeDetacher<'a> { +pub struct RangeDetacher<'a, 'p> { table_name: &'a str, column_id: &'a ColumnId, + arena: &'a PlanArena<'p>, } -impl<'a> RangeDetacher<'a> { - pub(crate) fn new(table_name: &'a str, column_id: &'a ColumnId) -> Self { +impl<'a, 'p> RangeDetacher<'a, 'p> { + pub(crate) fn new( + table_name: &'a str, + column_id: &'a ColumnId, + arena: &'a PlanArena<'p>, + ) -> Self { Self { table_name, column_id, + arena, } } @@ -228,6 +235,7 @@ impl<'a> RangeDetacher<'a> { } ScalarExpression::IsNull { expr, negated, .. } => match expr.as_ref() { ScalarExpression::ColumnRef { column, .. } => { + let column = self.arena.column(*column); if let (Some(col_id), Some(col_table)) = (column.id(), column.table_name()) { if &col_id == self.column_id && col_table.as_ref() == self.table_name { return if *negated { @@ -684,9 +692,11 @@ impl<'a> RangeDetacher<'a> { } } - fn _is_belong(table_name: &str, col: &ColumnRef) -> bool { + fn _is_belong(&self, col: ColumnRef) -> bool { + let col = self.arena.column(col); matches!( - col.table_name().map(|name| table_name == name.as_ref()), + col.table_name() + .map(|name| self.table_name == name.as_ref()), Some(true) ) } @@ -725,10 +735,11 @@ impl<'a> RangeDetacher<'a> { mut val: DataValue, is_flip: bool, ) -> Result, DatabaseError> { - if !Self::_is_belong(self.table_name, &col) || col.id() != Some(*self.column_id) { + let column = self.arena.column(col); + if !self._is_belong(col) || column.id() != Some(*self.column_id) { return Ok(None); } - val = val.cast(col.datatype())?; + val = val.cast(column.datatype())?; if is_flip { op = match op { BinaryOperator::Gt => BinaryOperator::Lt, @@ -813,13 +824,15 @@ mod test { use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::Operator; use crate::planner::LogicalPlan; - use crate::storage::rocksdb::RocksTransaction; use crate::types::evaluator::binary_create; use crate::types::value::DataValue; use crate::types::LogicalType; use std::ops::Bound; - fn plan_filter(plan: LogicalPlan) -> Result, DatabaseError> { + fn plan_filter( + plan: LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result, DatabaseError> { let pipeline = HepOptimizerPipeline::builder() .before_batch( "test_simplify_filter".to_string(), @@ -827,9 +840,7 @@ mod test { vec![NormalizationRuleImpl::SimplifyFilter], ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, arena)?; if let Operator::Filter(filter_op) = best_plan.childrens.pop_only().operator { Ok(Some(filter_op)) } else { @@ -840,10 +851,11 @@ mod test { #[test] fn test_detach_ideal_cases() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let mut plan_arena = crate::planner::PlanArena::new(&table_state.table_arena); { let plan = table_state.plan("select * from t1 where c1 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = 1 => {}", range); @@ -851,8 +863,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 = 1.0")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = 1.0 => {}", range); @@ -860,16 +872,16 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 != 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("c1 != 1 => {:#?}", range); assert_eq!(range, None) } { let plan = table_state.plan("select * from t1 where c1 > 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 > 1 => c1: {}", range); @@ -883,8 +895,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 >= 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 >= 1 => c1: {}", range); @@ -898,8 +910,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 < 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 < 1 => c1: {}", range); @@ -913,8 +925,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 <= 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 <= 1 => c1: {}", range); @@ -928,8 +940,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 < 1 and c1 >= 0")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 < 1 and c1 >= 0 => c1: {}", range); @@ -943,8 +955,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 < 1 or c1 >= 0")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 < 1 or c1 >= 0 => c1: {}", range); @@ -959,8 +971,8 @@ mod test { // and & or { let plan = table_state.plan("select * from t1 where c1 = 1 and c1 = 0")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = 1 and c1 = 0 => c1: {}", range); @@ -968,8 +980,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 = 1 or c1 = 0")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = 1 or c1 = 0 => c1: {}", range); @@ -983,8 +995,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 = 1 and c1 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = 1 and c1 = 1 => c1: {}", range); @@ -992,8 +1004,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 = 1 or c1 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = 1 or c1 = 1 => c1: {}", range); @@ -1002,8 +1014,8 @@ mod test { { let plan = table_state.plan("select * from t1 where c1 > 1 and c1 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 > 1 and c1 = 1 => c1: {}", range); @@ -1011,8 +1023,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 >= 1 and c1 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 >= 1 and c1 = 1 => c1: {}", range); @@ -1020,8 +1032,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 > 1 or c1 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 > 1 or c1 = 1 => c1: {}", range); @@ -1035,8 +1047,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 >= 1 or c1 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 >= 1 or c1 = 1 => c1: {}", range); @@ -1052,8 +1064,8 @@ mod test { { let plan = table_state .plan("select * from t1 where (c1 > 0 and c1 < 3) and (c1 > 1 and c1 < 4)")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!( @@ -1071,8 +1083,8 @@ mod test { { let plan = table_state .plan("select * from t1 where (c1 > 0 and c1 < 3) or (c1 > 1 and c1 < 4)")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!( @@ -1092,8 +1104,8 @@ mod test { let plan = table_state.plan( "select * from t1 where ((c1 > 0 and c1 < 3) and (c1 > 1 and c1 < 4)) and c1 = 0", )?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!( @@ -1106,8 +1118,8 @@ mod test { let plan = table_state.plan( "select * from t1 where ((c1 > 0 and c1 < 3) or (c1 > 1 and c1 < 4)) and c1 = 0", )?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!( @@ -1120,8 +1132,8 @@ mod test { let plan = table_state.plan( "select * from t1 where ((c1 > 0 and c1 < 3) and (c1 > 1 and c1 < 4)) or c1 = 0", )?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!( @@ -1143,8 +1155,8 @@ mod test { let plan = table_state.plan( "select * from t1 where ((c1 > 0 and c1 < 3) or (c1 > 1 and c1 < 4)) or c1 = 0", )?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!( @@ -1162,8 +1174,8 @@ mod test { { let plan = table_state.plan("select * from t1 where (((c1 > 0 and c1 < 3) and (c1 > 1 and c1 < 4)) and c1 = 0) and (c1 >= 0 and c1 <= 2)")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("(((c1 > 0 and c1 < 3) and (c1 > 1 and c1 < 4)) and c1 = 0) and (c1 >= 0 and c1 <= 2) => c1: {}", range); @@ -1171,8 +1183,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where (((c1 > 0 and c1 < 3) and (c1 > 1 and c1 < 4)) and c1 = 0) or (c1 >= 0 and c1 <= 2)")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("(((c1 > 0 and c1 < 3) and (c1 > 1 and c1 < 4)) and c1 = 0) or (c1 >= 0 and c1 <= 2) => c1: {}", range); @@ -1187,8 +1199,8 @@ mod test { // ranges and ranges { let plan = table_state.plan("select * from t1 where ((c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) and ((c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5))")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("((c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) and ((c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5)) => c1: {}", range); @@ -1208,8 +1220,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where ((c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) or ((c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5))")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("((c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) or ((c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5)) => c1: {}", range); @@ -1230,8 +1242,8 @@ mod test { // empty { let plan = table_state.plan("select * from t1 where true")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("empty => c1: {:#?}", range); assert_eq!(range, None) @@ -1239,24 +1251,24 @@ mod test { // other column { let plan = table_state.plan("select * from t1 where c2 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("c2 = 1 => c1: {:#?}", range); assert_eq!(range, None) } { let plan = table_state.plan("select * from t1 where c1 > 1 or c2 > 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("c1 > 1 or c2 > 1 => c1: {:#?}", range); assert_eq!(range, None) } { let plan = table_state.plan("select * from t1 where c1 > c2 or c2 > 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("c1 > c2 or c2 > 1 => c1: {:#?}", range); assert_eq!(range, None) @@ -1266,8 +1278,8 @@ mod test { let plan = table_state.plan( "select * from t1 where c1 = 5 or (c1 > 5 and (c1 > 6 or c1 < 8) and c1 < 12)", )?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!( @@ -1287,8 +1299,8 @@ mod test { let plan = table_state.plan( "select * from t1 where ((c2 >= -8 and -4 >= c1) or (c1 >= 0 and 5 > c2)) and ((c2 > 0 and c1 <= 1) or (c1 > -8 and c2 < -6))", )?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!( @@ -1316,10 +1328,12 @@ mod test { #[test] fn test_detach_only_conjunction_can_keep_partial_range() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let detach_c1 = |sql: &str| -> Result, DatabaseError> { + let mut plan_arena = crate::planner::PlanArena::new(&table_state.table_arena); + let mut detach_c1 = |sql: &str| -> Result, DatabaseError> { let plan = table_state.plan(sql)?; - let op = plan_filter(plan)?.unwrap(); - RangeDetacher::new("t1", table_state.column_id_by_name("c1")).detach(&op.predicate) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) + .detach(&op.predicate) }; assert_eq!( @@ -1358,11 +1372,12 @@ mod test { #[test] fn test_detach_null_cases() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let mut plan_arena = crate::planner::PlanArena::new(&table_state.table_arena); // eq { let plan = table_state.plan("select * from t1 where c1 = null")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = null => c1: {}", range); @@ -1370,8 +1385,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 = null or c1 = 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = null or c1 = 1 => c1: {}", range); @@ -1385,8 +1400,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 = null or c1 < 5")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = null or c1 < 5 => c1: {}", range); @@ -1401,8 +1416,8 @@ mod test { { let plan = table_state.plan("select * from t1 where c1 = null or (c1 > 1 and c1 < 5)")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = null or (c1 > 1 and c1 < 5) => c1: {}", range); @@ -1419,8 +1434,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where c1 = null and c1 < 5")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = null and c1 < 5 => c1: {}", range); @@ -1429,8 +1444,8 @@ mod test { { let plan = table_state.plan("select * from t1 where c1 = null and (c1 > 1 and c1 < 5)")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 = null and (c1 > 1 and c1 < 5) => c1: {}", range); @@ -1439,24 +1454,24 @@ mod test { // noteq { let plan = table_state.plan("select * from t1 where c1 != null")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("c1 != null => c1: {:#?}", range); assert_eq!(range, None) } { let plan = table_state.plan("select * from t1 where c1 = null or c1 != 1")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("c1 = null or c1 != 1 => c1: {:#?}", range); assert_eq!(range, None) } { let plan = table_state.plan("select * from t1 where c1 != null or c1 < 5")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("c1 != null or c1 < 5 => c1: {:#?}", range); assert_eq!(range, None) @@ -1464,16 +1479,16 @@ mod test { { let plan = table_state.plan("select * from t1 where c1 != null or (c1 > 1 and c1 < 5)")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)?; println!("c1 != null or (c1 > 1 and c1 < 5) => c1: {:#?}", range); assert_eq!(range, None) } { let plan = table_state.plan("select * from t1 where c1 != null and c1 < 5")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 != null and c1 < 5 => c1: {}", range); @@ -1488,8 +1503,8 @@ mod test { { let plan = table_state.plan("select * from t1 where c1 != null and (c1 > 1 and c1 < 5)")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("c1 != null and (c1 > 1 and c1 < 5) => c1: {}", range); @@ -1503,8 +1518,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where (c1 = null or (c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) or ((c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5))")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("(c1 = null or (c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) or ((c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5)) => c1: {}", range); @@ -1525,8 +1540,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where ((c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) or (c1 = null or (c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5))")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("((c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) or (c1 = null or (c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5)) => c1: {}", range); @@ -1547,8 +1562,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where (c1 = null or (c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) and ((c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5))")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("(c1 = null or (c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) and ((c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5)) => c1: {}", range); @@ -1568,8 +1583,8 @@ mod test { } { let plan = table_state.plan("select * from t1 where ((c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) and (c1 = null or (c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5))")?; - let op = plan_filter(plan)?.unwrap(); - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let op = plan_filter(plan, &mut plan_arena)?.unwrap(); + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &plan_arena) .detach(&op.predicate)? .unwrap(); println!("((c1 < 2 and c1 > 0) or (c1 < 6 and c1 > 4)) and (c1 = null or (c1 < 3 and c1 > 1) or (c1 < 7 and c1 > 5)) => c1: {}", range); diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 52fee043..537b1244 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -16,6 +16,7 @@ use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::visitor_mut::{walk_mut_expr, VisitorMut}; use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; +use crate::planner::PlanArena; use crate::types::evaluator::{binary_create, unary_create}; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -44,9 +45,17 @@ struct ReplaceUnary { ty: LogicalType, } -pub struct ConstantCalculator; +pub struct ConstantCalculator<'a, 'p> { + arena: &'a PlanArena<'p>, +} + +impl<'a, 'p> ConstantCalculator<'a, 'p> { + pub fn new(arena: &'a PlanArena<'p>) -> Self { + Self { arena } + } +} -impl VisitorMut<'_> for ConstantCalculator { +impl VisitorMut<'_> for ConstantCalculator<'_, '_> { fn visit(&mut self, expr: &'_ mut ScalarExpression) -> Result<(), DatabaseError> { match expr { ScalarExpression::Unary { @@ -72,8 +81,8 @@ impl VisitorMut<'_> for ConstantCalculator { right_expr, .. } => { - let left_ty = left_expr.return_type(); - let right_ty = right_expr.return_type(); + let left_ty = left_expr.return_type(self.arena); + let right_ty = right_expr.return_type(self.arena); let ty = LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned(); self.visit(left_expr)?; self.visit(right_expr)?; @@ -534,13 +543,6 @@ impl Simplify { }), ); } - - fn _is_belong(table_name: &str, col: &ColumnRef) -> bool { - matches!( - col.table_name().map(|name| table_name == name.as_ref()), - Some(true) - ) - } } impl ScalarExpression { @@ -598,7 +600,7 @@ impl ScalarExpression { pub(crate) fn unpack_bound_col(&self, is_deep: bool) -> Option<(ColumnRef, usize)> { match self { - ScalarExpression::ColumnRef { column, position } => Some((column.clone(), *position)), + ScalarExpression::ColumnRef { column, position } => Some((*column, *position)), ScalarExpression::Alias { expr, .. } => expr.unpack_bound_col(is_deep), ScalarExpression::Unary { expr, .. } => expr.unpack_bound_col(is_deep), ScalarExpression::Binary { diff --git a/src/expression/visitor.rs b/src/expression/visitor.rs index 562077e2..3c2047c9 100644 --- a/src/expression/visitor.rs +++ b/src/expression/visitor.rs @@ -19,7 +19,7 @@ use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; use crate::expression::TrimWhereField; use crate::expression::{AliasType, BinaryOperator, ScalarExpression, UnaryOperator}; -use crate::types::evaluator::{BinaryEvaluatorBox, CastEvaluatorBox, UnaryEvaluatorBox}; +use crate::types::evaluator::{BinaryEvaluatorRef, CastEvaluatorRef, UnaryEvaluatorRef}; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -51,7 +51,7 @@ pub trait Visitor<'a>: Sized { &mut self, expr: &'a ScalarExpression, _ty: &'a LogicalType, - _evaluator: Option<&'a CastEvaluatorBox>, + _evaluator: Option<&'a CastEvaluatorRef>, ) -> Result<(), DatabaseError> { self.visit(expr) } @@ -68,7 +68,7 @@ pub trait Visitor<'a>: Sized { &mut self, _op: &'a UnaryOperator, expr: &'a ScalarExpression, - _evaluator: Option<&'a UnaryEvaluatorBox>, + _evaluator: Option<&'a UnaryEvaluatorRef>, _ty: &'a LogicalType, ) -> Result<(), DatabaseError> { self.visit(expr) @@ -79,7 +79,7 @@ pub trait Visitor<'a>: Sized { _op: &'a BinaryOperator, left_expr: &'a ScalarExpression, right_expr: &'a ScalarExpression, - _evaluator: Option<&'a BinaryEvaluatorBox>, + _evaluator: Option<&'a BinaryEvaluatorRef>, _ty: &'a LogicalType, ) -> Result<(), DatabaseError> { self.visit(left_expr)?; diff --git a/src/expression/visitor_mut.rs b/src/expression/visitor_mut.rs index e15b20b6..f00377ed 100644 --- a/src/expression/visitor_mut.rs +++ b/src/expression/visitor_mut.rs @@ -19,7 +19,7 @@ use crate::expression::function::scala::ScalarFunction; use crate::expression::function::table::TableFunction; use crate::expression::TrimWhereField; use crate::expression::{AliasType, BinaryOperator, ScalarExpression, UnaryOperator}; -use crate::types::evaluator::{BinaryEvaluatorBox, CastEvaluatorBox, UnaryEvaluatorBox}; +use crate::types::evaluator::{BinaryEvaluatorRef, CastEvaluatorRef, UnaryEvaluatorRef}; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -74,7 +74,7 @@ pub trait VisitorMut<'a>: Sized { &mut self, expr: &'a mut ScalarExpression, _ty: &'a mut LogicalType, - _evaluator: &'a mut Option, + _evaluator: &'a mut Option, ) -> Result<(), DatabaseError> { self.visit(expr) } @@ -91,7 +91,7 @@ pub trait VisitorMut<'a>: Sized { &mut self, _op: &'a mut UnaryOperator, expr: &'a mut ScalarExpression, - _evaluator: &'a mut Option, + _evaluator: &'a mut Option, _ty: &'a mut LogicalType, ) -> Result<(), DatabaseError> { self.visit(expr) @@ -102,7 +102,7 @@ pub trait VisitorMut<'a>: Sized { _op: &'a mut BinaryOperator, left_expr: &'a mut ScalarExpression, right_expr: &'a mut ScalarExpression, - _evaluator: &'a mut Option, + _evaluator: &'a mut Option, _ty: &'a mut LogicalType, ) -> Result<(), DatabaseError> { self.visit(left_expr)?; diff --git a/src/function/char_length.rs b/src/function/char_length.rs index 2940f769..4a5190da 100644 --- a/src/function/char_length.rs +++ b/src/function/char_length.rs @@ -21,11 +21,9 @@ use crate::types::tuple::TupleLike; use crate::types::value::DataValue; use crate::types::CharLengthUnits; use crate::types::LogicalType; -use serde::Deserialize; -use serde::Serialize; use std::sync::Arc; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub(crate) struct CharLength { summary: FunctionSummary, } diff --git a/src/function/current_date.rs b/src/function/current_date.rs index 99bc308b..3c9cd705 100644 --- a/src/function/current_date.rs +++ b/src/function/current_date.rs @@ -21,11 +21,9 @@ use crate::types::tuple::TupleLike; use crate::types::value::DataValue; use crate::types::LogicalType; use chrono::{Datelike, Local}; -use serde::Deserialize; -use serde::Serialize; use std::sync::Arc; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub(crate) struct CurrentDate { summary: FunctionSummary, } diff --git a/src/function/current_timestamp.rs b/src/function/current_timestamp.rs index 3fac94c7..1d7563c3 100644 --- a/src/function/current_timestamp.rs +++ b/src/function/current_timestamp.rs @@ -21,11 +21,9 @@ use crate::types::tuple::TupleLike; use crate::types::value::DataValue; use crate::types::LogicalType; use chrono::Utc; -use serde::Deserialize; -use serde::Serialize; use std::sync::Arc; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub(crate) struct CurrentTimeStamp { summary: FunctionSummary, } diff --git a/src/function/lower.rs b/src/function/lower.rs index 021e1e31..c73a1232 100644 --- a/src/function/lower.rs +++ b/src/function/lower.rs @@ -21,11 +21,9 @@ use crate::types::tuple::TupleLike; use crate::types::value::DataValue; use crate::types::CharLengthUnits; use crate::types::LogicalType; -use serde::Deserialize; -use serde::Serialize; use std::sync::Arc; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub(crate) struct Lower { summary: FunctionSummary, } diff --git a/src/function/mod.rs b/src/function/mod.rs index 8c72a95a..77982a03 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -13,7 +13,9 @@ // limitations under the License. pub(crate) mod char_length; +#[cfg(feature = "time")] pub(crate) mod current_date; +#[cfg(feature = "time")] pub(crate) mod current_timestamp; pub(crate) mod lower; pub(crate) mod numbers; diff --git a/src/function/numbers.rs b/src/function/numbers.rs index 4445d109..7a84abd1 100644 --- a/src/function/numbers.rs +++ b/src/function/numbers.rs @@ -14,29 +14,17 @@ use crate::catalog::ColumnCatalog; use crate::catalog::ColumnDesc; -use crate::catalog::TableCatalog; +use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::expression::function::table::TableFunctionImpl; use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; -use crate::types::tuple::SchemaRef; +use crate::planner::TableArena; +use crate::types::tuple::Schema; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; use std::sync::Arc; -use std::sync::LazyLock; - -static NUMBERS: LazyLock = LazyLock::new(|| { - TableCatalog::new( - "numbers".to_string().into(), - vec![ColumnCatalog::new( - "number".to_string(), - true, - ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), - )], - ) - .unwrap() -}); #[derive(Debug)] pub(crate) struct Numbers { @@ -75,15 +63,23 @@ impl TableFunctionImpl for Numbers { ) } - fn output_schema(&self) -> &SchemaRef { - NUMBERS.schema_ref() - } - fn summary(&self) -> &FunctionSummary { &self.summary } - fn table(&self) -> &TableCatalog { - &NUMBERS + fn output_schema_into( + &self, + table_name: &TableName, + table_arena: &mut TableArena, + schema: &mut Schema, + ) { + schema.push(table_arena.alloc_table_column( + table_name.clone(), + ColumnCatalog::new( + "number".to_string(), + true, + ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), + ), + )); } } diff --git a/src/function/octet_length.rs b/src/function/octet_length.rs index af38e5ef..a41b779f 100644 --- a/src/function/octet_length.rs +++ b/src/function/octet_length.rs @@ -21,11 +21,9 @@ use crate::types::tuple::TupleLike; use crate::types::value::DataValue; use crate::types::CharLengthUnits; use crate::types::LogicalType; -use serde::Deserialize; -use serde::Serialize; use std::sync::Arc; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub(crate) struct OctetLength { summary: FunctionSummary, } diff --git a/src/function/upper.rs b/src/function/upper.rs index 417991a9..991896c1 100644 --- a/src/function/upper.rs +++ b/src/function/upper.rs @@ -21,11 +21,9 @@ use crate::types::tuple::TupleLike; use crate::types::value::DataValue; use crate::types::CharLengthUnits; use crate::types::LogicalType; -use serde::Deserialize; -use serde::Serialize; use std::sync::Arc; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] pub(crate) struct Upper { summary: FunctionSummary, } diff --git a/src/lib.rs b/src/lib.rs index 40b161cf..9d5e4235 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,13 +50,14 @@ //! [`Database::new_transaction`](db::Database::new_transaction) method. //! //! support UDF (User-Defined Function) so that users can customize internal calculation functions -//! with the [`DataBaseBuilder::register_function`](db::DataBaseBuilder::register_scala_function) +//! with [`Database::load`](db::Database::load) and [`CatalogKind`](db::CatalogKind) //! //! # Examples //! //! ```ignore -//! use kite_sql::db::{DataBaseBuilder, ResultIter}; +//! use kite_sql::db::DataBaseBuilder; //! use kite_sql::errors::DatabaseError; +//! use kite_sql::orm::OrmQueryResultExt; //! use kite_sql::Model; //! //! #[derive(Default, Debug, PartialEq, Model)] @@ -82,7 +83,15 @@ //! c2: "one".to_string(), //! })?; //! -//! for row in database.fetch::()? { +//! let rows = database +//! .bind(|ctx| { +//! ctx.from::()? +//! .filter(|e| e.column(MyStruct::c1())?.gte(1))? +//! .project_scalars((MyStruct::c1(), MyStruct::c2())) +//! })? +//! .project_tuple::<(i32, String)>(); +//! +//! for row in rows { //! println!("{:?}", row?); //! } //! database.drop_table::()?; @@ -106,6 +115,7 @@ pub mod macros; mod optimizer; #[cfg(feature = "orm")] pub mod orm; +#[cfg(feature = "parser")] pub mod parser; pub mod planner; #[cfg(all(not(target_arch = "wasm32"), feature = "python"))] @@ -113,7 +123,6 @@ pub mod python; pub mod serdes; pub mod storage; pub mod types; -pub(crate) mod utils; #[cfg(target_arch = "wasm32")] pub mod wasm; diff --git a/src/macros/mod.rs b/src/macros/mod.rs index aee17e8a..7e8f5607 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -38,14 +38,11 @@ #[macro_export] macro_rules! from_tuple { ($struct_name:ident, ($($field_name:ident : $field_type:ty => $closure:expr),+)) => { - impl From<(&::kite_sql::types::tuple::SchemaRef, ::kite_sql::types::tuple::Tuple)> for $struct_name { - fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaRef, ::kite_sql::types::tuple::Tuple)) -> Self { - fn try_get(tuple: &mut ::kite_sql::types::tuple::Tuple, schema: &::kite_sql::types::tuple::SchemaRef, field_name: &str) -> Option<::kite_sql::types::value::DataValue> { + impl<'__kite_schema, '__kite_arena> From<(&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)> for $struct_name { + fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)) -> Self { + fn try_get(tuple: &mut ::kite_sql::types::tuple::Tuple, schema: &::kite_sql::types::tuple::SchemaView<'_, '_>, field_name: &str) -> Option<::kite_sql::types::value::DataValue> { let ty = ::kite_sql::types::LogicalType::type_trans::()?; - let (idx, _) = schema - .iter() - .enumerate() - .find(|(_, col)| col.name() == field_name)?; + let idx = schema.position(field_name)?; std::mem::replace(&mut tuple.values[idx], ::kite_sql::types::value::DataValue::Null).cast(&ty).ok() } @@ -72,10 +69,8 @@ macro_rules! from_tuple { /// DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus) /// }); /// -/// let kite_sql = DataBaseBuilder::path("./example") -/// .register_scala_function(TestFunction::new()) -/// .build() -/// ?; +/// let mut kite_sql = DataBaseBuilder::path("./example").build()?; +/// kite_sql.load(CatalogKind::ScalarFunction(TestFunction::new()))?; /// ``` #[macro_export] macro_rules! scala_function { @@ -145,24 +140,12 @@ macro_rules! scala_function { /// ])))) as Box>>) /// })); /// -/// let kite_sql = DataBaseBuilder::path("./example") -/// .register_table_function(MyTableFunction::new()) -/// .build() -/// ?; +/// let mut kite_sql = DataBaseBuilder::path("./example").build()?; +/// kite_sql.load(CatalogKind::TableFunction(MyTableFunction::new()))?; /// ``` #[macro_export] macro_rules! table_function { ($struct_name:ident::$function_name:ident($($arg_ty:expr),*) -> [$($output_name:ident: $output_ty:expr),*] => $closure:expr) => { - static $function_name: ::std::sync::LazyLock<::kite_sql::catalog::table::TableCatalog> = ::std::sync::LazyLock::new(|| { - let mut columns = Vec::new(); - - $({ - columns.push(::kite_sql::catalog::column::ColumnCatalog::new(stringify!($output_name).to_lowercase(), true, ::kite_sql::catalog::column::ColumnDesc::new($output_ty, None, false, None).unwrap())); - })* - - ::kite_sql::catalog::table::TableCatalog::new(stringify!($function_name).to_lowercase().into(), columns).unwrap() - }); - #[derive(Debug)] pub(crate) struct $struct_name { summary: ::kite_sql::expression::function::FunctionSummary, @@ -201,16 +184,22 @@ macro_rules! table_function { }, )*) } - fn output_schema(&self) -> &::kite_sql::types::tuple::SchemaRef { - $function_name.schema_ref() - } - fn summary(&self) -> &::kite_sql::expression::function::FunctionSummary { &self.summary } - fn table(&self) -> &::kite_sql::catalog::table::TableCatalog { - &$function_name + fn output_schema_into( + &self, + table_name: &::kite_sql::catalog::table::TableName, + table_arena: &mut ::kite_sql::planner::TableArena, + schema: &mut ::kite_sql::types::tuple::Schema, + ) { + $({ + schema.push(table_arena.alloc_table_column( + table_name.clone(), + ::kite_sql::catalog::column::ColumnCatalog::new(stringify!($output_name).to_lowercase(), true, ::kite_sql::catalog::column::ColumnDesc::new($output_ty, None, false, None).unwrap()), + )); + })* } } }; diff --git a/src/optimizer/core/cm_sketch.rs b/src/optimizer/core/cm_sketch.rs index 56063875..68ede2d7 100644 --- a/src/optimizer/core/cm_sketch.rs +++ b/src/optimizer/core/cm_sketch.rs @@ -334,33 +334,39 @@ impl CountMinSketch { } impl ReferenceSerialization for CountMinSketch { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.counters.encode(writer, is_direct, reference_tables)?; - self.offsets.encode(writer, is_direct, reference_tables)?; - self.hashers[0].encode(writer, is_direct, reference_tables)?; - self.hashers[1].encode(writer, is_direct, reference_tables)?; - self.mask.encode(writer, is_direct, reference_tables)?; - self.k_num.encode(writer, is_direct, reference_tables)?; + self.counters + .encode(writer, is_direct, reference_tables, arena)?; + self.offsets + .encode(writer, is_direct, reference_tables, arena)?; + self.hashers[0].encode(writer, is_direct, reference_tables, arena)?; + self.hashers[1].encode(writer, is_direct, reference_tables, arena)?; + self.mask + .encode(writer, is_direct, reference_tables, arena)?; + self.k_num + .encode(writer, is_direct, reference_tables, arena)?; Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let counters = Vec::>::decode(reader, drive, reference_tables)?; - let offsets = Vec::::decode(reader, drive, reference_tables)?; - let hasher_0 = FastHasher::decode(reader, drive, reference_tables)?; - let hasher_1 = FastHasher::decode(reader, drive, reference_tables)?; - let mask = usize::decode(reader, drive, reference_tables)?; - let k_num = usize::decode(reader, drive, reference_tables)?; + let counters = Vec::>::decode(reader, drive, reference_tables, arena)?; + let offsets = Vec::::decode(reader, drive, reference_tables, arena)?; + let hasher_0 = FastHasher::decode(reader, drive, reference_tables, arena)?; + let hasher_1 = FastHasher::decode(reader, drive, reference_tables, arena)?; + let mask = usize::decode(reader, drive, reference_tables, arena)?; + let k_num = usize::decode(reader, drive, reference_tables, arena)?; Ok(CountMinSketch { counters, diff --git a/src/optimizer/core/histogram.rs b/src/optimizer/core/histogram.rs index cef0ac4a..cfbbda7a 100644 --- a/src/optimizer/core/histogram.rs +++ b/src/optimizer/core/histogram.rs @@ -18,7 +18,7 @@ use crate::expression::range_detacher::Range; use crate::expression::BinaryOperator; use crate::optimizer::core::cm_sketch::CountMinSketch; use crate::storage::table_codec::BumpBytes; -use crate::types::evaluator::{binary_create, BinaryEvaluatorBox}; +use crate::types::evaluator::{binary_create, BinaryEvaluatorRef}; use crate::types::index::{IndexId, IndexMeta}; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -45,10 +45,10 @@ pub struct HistogramBuilder { #[derive(Debug)] struct BoundComparator { - lt: BinaryEvaluatorBox, - lte: BinaryEvaluatorBox, - gt: BinaryEvaluatorBox, - gte: BinaryEvaluatorBox, + lt: BinaryEvaluatorRef, + lte: BinaryEvaluatorRef, + gt: BinaryEvaluatorRef, + gte: BinaryEvaluatorRef, } #[derive(Debug, Clone, PartialEq, ReferenceSerialization)] @@ -414,7 +414,7 @@ impl Histogram { &mut bucket_idxs, &mut count, sketch, - &comparator, + comparator, )?; if is_dummy { return Ok(0); diff --git a/src/optimizer/core/rule.rs b/src/optimizer/core/rule.rs index 303cb578..ffc62673 100644 --- a/src/optimizer/core/rule.rs +++ b/src/optimizer/core/rule.rs @@ -16,8 +16,7 @@ use crate::errors::DatabaseError; use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption}; -use crate::planner::LogicalPlan; -use crate::storage::Transaction; +use crate::planner::{LogicalPlan, PlanArena}; use std::cmp::Ordering; pub type BestPhysicalOption = Option<(PhysicalOption, Option)>; @@ -29,7 +28,7 @@ pub trait MatchPattern { pub trait NormalizationRule { /// Returns true when the plan tree is modified. - fn apply(&self, plan: &mut LogicalPlan) -> Result; + fn apply(&self, plan: &mut LogicalPlan, arena: &mut PlanArena) -> Result; } fn compare_costs(candidate_cost: Option, best_cost: Option) -> Ordering { @@ -56,11 +55,12 @@ pub fn keep_best_physical_option( } } -pub trait ImplementationRule: MatchPattern { +pub trait ImplementationRule: MatchPattern { fn update_best_option( &self, op: &Operator, - loader: &StatisticMetaLoader, + arena: &PlanArena, + loader: &StatisticMetaLoader<'_>, best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError>; } diff --git a/src/optimizer/core/statistics_meta.rs b/src/optimizer/core/statistics_meta.rs index 7697ddb8..468854d3 100644 --- a/src/optimizer/core/statistics_meta.rs +++ b/src/optimizer/core/statistics_meta.rs @@ -17,20 +17,19 @@ use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::optimizer::core::cm_sketch::CountMinSketch; use crate::optimizer::core::histogram::{Bucket, Histogram, HistogramMeta}; -use crate::storage::{StatisticsMetaCache, Transaction}; +use crate::storage::StatisticsMetaCache; use crate::types::index::IndexId; use crate::types::value::DataValue; use kite_sql_serde_macros::ReferenceSerialization; use std::slice; -pub struct StatisticMetaLoader<'a, T: Transaction> { +pub struct StatisticMetaLoader<'a> { cache: &'a StatisticsMetaCache, - tx: &'a T, } -impl<'a, T: Transaction> StatisticMetaLoader<'a, T> { - pub fn new(tx: &'a T, cache: &'a StatisticsMetaCache) -> StatisticMetaLoader<'a, T> { - StatisticMetaLoader { cache, tx } +impl<'a> StatisticMetaLoader<'a> { + pub fn new(cache: &'a StatisticsMetaCache) -> StatisticMetaLoader<'a> { + StatisticMetaLoader { cache } } pub fn load( @@ -39,19 +38,7 @@ impl<'a, T: Transaction> StatisticMetaLoader<'a, T> { index_id: IndexId, ) -> Result, DatabaseError> { let key = (table_name.clone(), index_id); - match self.cache.get(&key) { - Some(Some(entry)) => return Ok(Some(entry)), - Some(None) => return Ok(None), - _ => {} - } - - let Some(statistics_meta) = self.tx.statistics_meta(table_name.as_ref(), index_id)? else { - self.cache.put(key, None); - return Ok(None); - }; - self.cache.put(key.clone(), Some(statistics_meta)); - - Ok(self.cache.get(&key).and_then(|entry| entry.as_ref())) + Ok(self.cache.get(&key)) } pub fn collect_count( diff --git a/src/optimizer/heuristic/matcher.rs b/src/optimizer/heuristic/matcher.rs index 6ef36ab7..196155f9 100644 --- a/src/optimizer/heuristic/matcher.rs +++ b/src/optimizer/heuristic/matcher.rs @@ -89,30 +89,16 @@ mod tests { #[test] fn test_recursive() { - let all_dummy_plan = LogicalPlan { - operator: Operator::Dummy, - childrens: Box::new(Childrens::Twins { - left: Box::new(LogicalPlan { - operator: Operator::Dummy, - childrens: Box::new(Childrens::Only(Box::new(LogicalPlan { - operator: Operator::Dummy, - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }))), - physical_option: None, - _output_schema_ref: None, - }), - right: Box::new(LogicalPlan { - operator: Operator::Dummy, - childrens: Box::new(Childrens::None), - physical_option: None, - _output_schema_ref: None, - }), - }), - physical_option: None, - _output_schema_ref: None, - }; + let all_dummy_plan = LogicalPlan::new( + Operator::Dummy, + Childrens::Twins { + left: Box::new(LogicalPlan::new( + Operator::Dummy, + Childrens::Only(Box::new(LogicalPlan::new(Operator::Dummy, Childrens::None))), + )), + right: Box::new(LogicalPlan::new(Operator::Dummy, Childrens::None)), + }, + ); let only_dummy_pattern = Pattern { predicate: |p| matches!(p, Operator::Dummy), diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index 9af51686..132fbb7d 100644 --- a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -28,12 +28,11 @@ use crate::optimizer::rule::normalization::{ use crate::planner::operator::join::JoinCondition; use crate::planner::operator::table_scan::TableScanOperator; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; -use crate::planner::{Childrens, LogicalPlan}; -use crate::storage::Transaction; +use crate::planner::{Childrens, LogicalPlan, PlanArena}; use std::array; use std::ops::Not; -type ScanHintApplier<'a> = dyn Fn(&mut TableScanOperator) + 'a; +type ScanHintApplier<'a> = dyn Fn(&mut TableScanOperator, &PlanArena) + 'a; pub struct HepOptimizer<'a> { before_batches: &'a [HepBatch], @@ -60,27 +59,40 @@ impl<'a> HepOptimizer<'a> { } } - pub fn find_best( + pub fn find_best( mut self, - loader: Option<&StatisticMetaLoader<'_, T>>, + loader: Option<&StatisticMetaLoader<'_>>, + arena: &mut PlanArena, ) -> Result { let mut applied_rules = Vec::with_capacity(self.max_local_rules_len); - Self::apply_batches(&mut self.plan, self.before_batches, &mut applied_rules)?; + Self::apply_batches( + &mut self.plan, + self.before_batches, + &mut applied_rules, + arena, + )?; if let Some(loader) = loader { if self.implementation_index.is_empty().not() { - let apply_no_sort_hints = |_scan_op: &mut TableScanOperator| {}; - let apply_no_stream_distinct_hints = |_scan_op: &mut TableScanOperator| {}; + let apply_no_sort_hints = |_scan_op: &mut TableScanOperator, _arena: &PlanArena| {}; + let apply_no_stream_distinct_hints = + |_scan_op: &mut TableScanOperator, _arena: &PlanArena| {}; Self::annotate_hints_and_physical_options( &mut self.plan, loader, self.implementation_index, &apply_no_sort_hints, &apply_no_stream_distinct_hints, + arena, )?; } } - Self::apply_batches(&mut self.plan, self.after_batches, &mut applied_rules)?; + Self::apply_batches( + &mut self.plan, + self.after_batches, + &mut applied_rules, + arena, + )?; Ok(self.plan) } @@ -90,18 +102,19 @@ impl<'a> HepOptimizer<'a> { plan: &mut LogicalPlan, batches: &[HepBatch], applied_rules: &mut Vec, + arena: &mut PlanArena, ) -> Result<(), DatabaseError> { for batch in batches { match batch.strategy { HepBatchStrategy::MaxTimes(max_iteration) => { for _ in 0..max_iteration { - if !Self::apply_batch(plan, batch, applied_rules)? { + if !Self::apply_batch(plan, batch, applied_rules, arena)? { break; } } } HepBatchStrategy::LoopIfApplied => { - while Self::apply_batch(plan, batch, applied_rules)? {} + while Self::apply_batch(plan, batch, applied_rules, arena)? {} } } } @@ -113,18 +126,18 @@ impl<'a> HepOptimizer<'a> { plan: &mut LogicalPlan, batch: &HepBatch, applied_rules: &mut Vec, + arena: &mut PlanArena, ) -> Result { let mut applied = false; for step in &batch.steps { match step { HepBatchStep::WholeTree(pass) => { - if Self::apply_whole_tree_pass(plan, pass)? { - plan.reset_output_schema_cache_recursive(); + if Self::apply_whole_tree_pass(plan, pass, arena)? { applied = true; } } HepBatchStep::LocalRewrite(rules) => { - if Self::apply_local_rules(plan, rules, applied_rules)? { + if Self::apply_local_rules(plan, rules, applied_rules, arena)? { applied = true; } } @@ -136,12 +149,16 @@ impl<'a> HepOptimizer<'a> { fn apply_whole_tree_pass( plan: &mut LogicalPlan, pass: &HepWholeTreePass, + arena: &mut PlanArena, ) -> Result { match pass.kind { WholeTreePassKind::ColumnPruning => { let mut applied = false; for rule in &pass.rules { - applied |= rule.apply(plan)?; + applied |= rule.apply(plan, arena)?; + } + if applied { + plan.reset_output_schema_cache_recursive(); } Ok(applied) } @@ -159,7 +176,9 @@ impl<'a> HepOptimizer<'a> { plan, has_constant_calculation, has_evaluator_bind, + arena, )?; + plan.reset_output_schema_cache_recursive(); Ok(true) } } @@ -169,11 +188,13 @@ impl<'a> HepOptimizer<'a> { plan: &mut LogicalPlan, has_constant_calculation: bool, has_evaluator_bind: bool, + arena: &mut PlanArena, ) -> Result<(), DatabaseError> { Self::apply_expression_rewrite_pass_inner( plan, has_constant_calculation, has_evaluator_bind, + arena, ) } @@ -181,6 +202,7 @@ impl<'a> HepOptimizer<'a> { plan: &mut LogicalPlan, has_constant_calculation: bool, has_evaluator_bind: bool, + arena: &mut PlanArena, ) -> Result<(), DatabaseError> { match plan.childrens.as_mut() { Childrens::Only(child) => { @@ -188,6 +210,7 @@ impl<'a> HepOptimizer<'a> { child, has_constant_calculation, has_evaluator_bind, + arena, )?; } Childrens::Twins { left, right } => { @@ -195,36 +218,39 @@ impl<'a> HepOptimizer<'a> { left, has_constant_calculation, has_evaluator_bind, + arena, )?; Self::apply_expression_rewrite_pass_inner( right, has_constant_calculation, has_evaluator_bind, + arena, )?; } Childrens::None => {} } if has_constant_calculation { - constant_calculation_current(plan)?; + constant_calculation_current(plan, arena)?; } if has_evaluator_bind { - evaluator_bind_current(plan)?; + evaluator_bind_current(plan, arena)?; } Ok(()) } - fn annotate_hints_and_physical_options<'plan, T: Transaction>( + fn annotate_hints_and_physical_options<'plan>( plan: &'plan mut LogicalPlan, - loader: &StatisticMetaLoader<'_, T>, + loader: &StatisticMetaLoader<'_>, implementation_index: &ImplementationRuleIndex, inherited_sort_hints: &'plan ScanHintApplier<'plan>, inherited_stream_distinct_hints: &'plan ScanHintApplier<'plan>, + arena: &mut PlanArena, ) -> Result<(), DatabaseError> { if let Operator::TableScan(scan_op) = &mut plan.operator { - inherited_sort_hints(scan_op); - inherited_stream_distinct_hints(scan_op); + inherited_sort_hints(scan_op, arena); + inherited_stream_distinct_hints(scan_op, arena); } { @@ -238,7 +264,7 @@ impl<'a> HepOptimizer<'a> { } else { let mut best_physical_option: BestPhysicalOption = None; for rule in implementation_index.for_matching_operator(operator) { - rule.update_best_option(operator, loader, &mut best_physical_option)?; + rule.update_best_option(operator, arena, loader, &mut best_physical_option)?; } if let Some((option, _)) = best_physical_option { *physical_option = Some(option); @@ -263,6 +289,7 @@ impl<'a> HepOptimizer<'a> { implementation_index, child_sort_hints, child_stream_distinct_hints, + arena, ), Childrens::Twins { left, right } => { Self::annotate_hints_and_physical_options( @@ -271,6 +298,7 @@ impl<'a> HepOptimizer<'a> { implementation_index, child_sort_hints, child_stream_distinct_hints, + arena, )?; Self::annotate_hints_and_physical_options( right, @@ -278,6 +306,7 @@ impl<'a> HepOptimizer<'a> { implementation_index, child_sort_hints, child_stream_distinct_hints, + arena, ) } Childrens::None => Ok(()), @@ -286,7 +315,7 @@ impl<'a> HepOptimizer<'a> { })?; } - apply_annotated_post_rules(plan)?; + apply_annotated_post_rules(plan, arena)?; Ok(()) } @@ -307,19 +336,20 @@ impl<'a> HepOptimizer<'a> { match operator { Operator::Sort(op) => { - let child_sort_hints = |scan_op: &mut TableScanOperator| { - inherited_sort_hints(scan_op); + let child_sort_hints = |scan_op: &mut TableScanOperator, arena: &PlanArena| { + inherited_sort_hints(scan_op, arena); apply_scan_order_hint( scan_op, ScanOrderHint::sort_fields(&op.sort_fields), OrderHintKind::SortElimination, + arena, ); }; f(&child_sort_hints) } _ if propagate_hints => f(inherited_sort_hints), _ => { - let no_sort_hints = |_scan_op: &mut TableScanOperator| {}; + let no_sort_hints = |_scan_op: &mut TableScanOperator, _arena: &PlanArena| {}; f(&no_sort_hints) } } @@ -343,18 +373,21 @@ impl<'a> HepOptimizer<'a> { Operator::Aggregate(op) if op.is_distinct && op.agg_calls.is_empty() && !op.groupby_exprs.is_empty() => { - let child_stream_distinct_hints = |scan_op: &mut TableScanOperator| { - apply_scan_order_hint( - scan_op, - ScanOrderHint::distinct_groupby(&op.groupby_exprs), - OrderHintKind::StreamDistinct, - ); - }; + let child_stream_distinct_hints = + |scan_op: &mut TableScanOperator, arena: &PlanArena| { + apply_scan_order_hint( + scan_op, + ScanOrderHint::distinct_groupby(&op.groupby_exprs), + OrderHintKind::StreamDistinct, + arena, + ); + }; f(&child_stream_distinct_hints) } _ if propagate_hints => f(inherited_stream_distinct_hints), _ => { - let no_stream_distinct_hints = |_scan_op: &mut TableScanOperator| {}; + let no_stream_distinct_hints = + |_scan_op: &mut TableScanOperator, _arena: &PlanArena| {}; f(&no_stream_distinct_hints) } } @@ -364,16 +397,18 @@ impl<'a> HepOptimizer<'a> { plan: &mut LogicalPlan, rules: &HepLocalRewriteBatch, applied_rules: &mut Vec, + arena: &mut PlanArena, ) -> Result { applied_rules.clear(); applied_rules.resize(rules.len(), false); - Self::apply_local_rules_inner(plan, rules, applied_rules) + Self::apply_local_rules_inner(plan, rules, applied_rules, arena) } fn apply_local_rules_inner( plan: &mut LogicalPlan, rules: &HepLocalRewriteBatch, applied_rules: &mut [bool], + arena: &mut PlanArena, ) -> Result { let mut applied = false; let mut next_rule_idx = 0; @@ -384,31 +419,36 @@ impl<'a> HepOptimizer<'a> { if applied_rules[idx] { continue; } - let applied_rule = rule.apply(plan)?; + let applied_rule = rule.apply(plan, arena)?; if applied_rule { - plan.reset_output_schema_cache_recursive(); applied_rules[idx] = true; applied = true; + plan.reset_output_schema_cache_recursive(); } } match plan.childrens.as_mut() { Childrens::Only(child) => { - let child_applied = Self::apply_local_rules_inner(child, rules, applied_rules)?; + let child_applied = + Self::apply_local_rules_inner(child, rules, applied_rules, arena)?; applied |= child_applied; + if child_applied { + plan.reset_output_schema_cache(); + } } Childrens::Twins { left, right } => { - let left_applied = Self::apply_local_rules_inner(left, rules, applied_rules)?; - let right_applied = Self::apply_local_rules_inner(right, rules, applied_rules)?; + let left_applied = + Self::apply_local_rules_inner(left, rules, applied_rules, arena)?; + let right_applied = + Self::apply_local_rules_inner(right, rules, applied_rules, arena)?; applied |= left_applied || right_applied; + if left_applied || right_applied { + plan.reset_output_schema_cache(); + } } Childrens::None => {} } - if applied { - plan.reset_output_schema_cache(); - } - Ok(applied) } } @@ -518,12 +558,14 @@ impl ImplementationRuleIndex { Operator::Analyze(_) if self.contains(ImplementationRuleImpl::Analyze) => { Some(PhysicalOption::new(PlanImpl::Analyze, SortOption::None)) } + #[cfg(feature = "copy")] Operator::CopyFromFile(_) if self.contains(ImplementationRuleImpl::CopyFromFile) => { Some(PhysicalOption::new( PlanImpl::CopyFromFile, SortOption::None, )) } + #[cfg(feature = "copy")] Operator::CopyToFile(_) if self.contains(ImplementationRuleImpl::CopyToFile) => { Some(PhysicalOption::new(PlanImpl::CopyToFile, SortOption::None)) } @@ -579,6 +621,7 @@ mod tests { use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::expression::ScalarExpression; + use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; use crate::optimizer::rule::implementation::ImplementationRuleImpl; @@ -590,27 +633,21 @@ mod tests { use crate::types::value::DataValue; use crate::types::LogicalType; use std::ops::Bound; - use std::sync::atomic::AtomicUsize; - use std::sync::Arc; use tempfile::TempDir; #[test] fn test_find_best_selects_cheapest_scan() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let database = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - database - .run("create table t1 (c1 int primary key, c2 int)")? - .done()?; - database - .run("create table t2 (c3 int primary key, c4 int)")? - .done()?; + let mut database = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + database.ddl("create table t1 (c1 int primary key, c2 int)")?; + database.ddl("create table t2 (c3 int primary key, c4 int)")?; for i in 0..1000 { database .run(format!("insert into t1 values({}, {})", i, i + 1).as_str())? .done()?; } - database.run("analyze table t1")?.done()?; + database.analyze("t1")?; let transaction = database.storage.transaction()?; let c1_column = transaction @@ -618,8 +655,9 @@ mod tests { .unwrap() .get_column_by_name("c1") .unwrap(); + let mut plan_arena = crate::planner::PlanArena::new(database.state.table_arena()); let sort_fields = vec![SortField::new( - ScalarExpression::column_expr(c1_column.clone(), 0), + ScalarExpression::column_expr(c1_column, 0), true, false, )]; @@ -632,7 +670,6 @@ mod tests { &transaction, &scala_functions, &table_functions, - Arc::new(AtomicUsize::new(0)), ), &[], None, @@ -640,7 +677,8 @@ mod tests { let stmt = crate::parser::parse_sql( "select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22", )?; - let plan = binder.bind(&stmt[0])?; + let stmt = stmt.into_iter().next().unwrap(); + let plan = binder.bind(&stmt, &mut plan_arena)?; let pipeline = HepOptimizerPipeline::builder() .before_batch( "Simplify Filter".to_string(), @@ -665,9 +703,20 @@ mod tests { ]) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best(Some(&transaction.meta_loader(database.state.meta_cache())))?; + let best_plan = pipeline.instantiate(plan).find_best( + Some(&StatisticMetaLoader::new(database.state.meta_cache())), + &mut plan_arena, + )?; + + let expected_index_meta = plan_arena.alloc_index(IndexMeta { + id: 0, + column_ids: vec![plan_arena.column(c1_column).id().unwrap()], + table_name: "t1".to_string().into(), + pk_ty: LogicalType::Integer, + value_ty: LogicalType::Integer, + name: "pk_index".to_string(), + ty: IndexType::PrimaryKey { is_multiple: false }, + }); assert_eq!( best_plan @@ -681,15 +730,7 @@ mod tests { .physical_option, Some(PhysicalOption::new( PlanImpl::IndexScan(Box::new(IndexInfo { - meta: Arc::new(IndexMeta { - id: 0, - column_ids: vec![c1_column.id().unwrap()], - table_name: "t1".to_string().into(), - pk_ty: LogicalType::Integer, - value_ty: LogicalType::Integer, - name: "pk_index".to_string(), - ty: IndexType::PrimaryKey { is_multiple: false }, - }), + meta: expected_index_meta, sort_option: SortOption::OrderBy { fields: sort_fields.clone(), ignore_prefix_len: 0, diff --git a/src/optimizer/plan_utils.rs b/src/optimizer/plan_utils.rs index 39666769..fe81a625 100644 --- a/src/optimizer/plan_utils.rs +++ b/src/optimizer/plan_utils.rs @@ -40,14 +40,6 @@ pub fn only_child_mut(plan: &mut LogicalPlan) -> Option<&mut LogicalPlan> { } } -pub fn left_child(plan: &LogicalPlan) -> Option<&LogicalPlan> { - match plan.childrens.as_ref() { - Childrens::Only(child) => Some(child.as_ref()), - Childrens::Twins { left, .. } => Some(left.as_ref()), - Childrens::None => None, - } -} - #[allow(dead_code)] pub fn left_child_mut(plan: &mut LogicalPlan) -> Option<&mut LogicalPlan> { match plan.childrens.as_mut() { @@ -57,13 +49,6 @@ pub fn left_child_mut(plan: &mut LogicalPlan) -> Option<&mut LogicalPlan> { } } -pub fn right_child(plan: &LogicalPlan) -> Option<&LogicalPlan> { - match plan.childrens.as_ref() { - Childrens::Twins { right, .. } => Some(right.as_ref()), - _ => None, - } -} - #[allow(dead_code)] pub fn right_child_mut(plan: &mut LogicalPlan) -> Option<&mut LogicalPlan> { match plan.childrens.as_mut() { diff --git a/src/optimizer/rule/implementation/ddl/add_column.rs b/src/optimizer/rule/implementation/ddl/add_column.rs index bd17bb3b..824e28a4 100644 --- a/src/optimizer/rule/implementation/ddl/add_column.rs +++ b/src/optimizer/rule/implementation/ddl/add_column.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static ADD_COLUMN_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/ddl/change_column.rs b/src/optimizer/rule/implementation/ddl/change_column.rs index 443b53ed..658e2be6 100644 --- a/src/optimizer/rule/implementation/ddl/change_column.rs +++ b/src/optimizer/rule/implementation/ddl/change_column.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static CHANGE_COLUMN_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/ddl/create_table.rs b/src/optimizer/rule/implementation/ddl/create_table.rs index 267e0c19..5fc8c987 100644 --- a/src/optimizer/rule/implementation/ddl/create_table.rs +++ b/src/optimizer/rule/implementation/ddl/create_table.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static CREATE_TABLE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/ddl/drop_column.rs b/src/optimizer/rule/implementation/ddl/drop_column.rs index f79a04d4..25646f9a 100644 --- a/src/optimizer/rule/implementation/ddl/drop_column.rs +++ b/src/optimizer/rule/implementation/ddl/drop_column.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static DROP_COLUMN_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/ddl/drop_table.rs b/src/optimizer/rule/implementation/ddl/drop_table.rs index 3914cad6..3da2881e 100644 --- a/src/optimizer/rule/implementation/ddl/drop_table.rs +++ b/src/optimizer/rule/implementation/ddl/drop_table.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static DROP_TABLE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/ddl/truncate.rs b/src/optimizer/rule/implementation/ddl/truncate.rs index e4b31596..20752355 100644 --- a/src/optimizer/rule/implementation/ddl/truncate.rs +++ b/src/optimizer/rule/implementation/ddl/truncate.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static TRUNCATE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dml/analyze.rs b/src/optimizer/rule/implementation/dml/analyze.rs index 6682f285..8a361947 100644 --- a/src/optimizer/rule/implementation/dml/analyze.rs +++ b/src/optimizer/rule/implementation/dml/analyze.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static ANALYZE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dml/copy_from_file.rs b/src/optimizer/rule/implementation/dml/copy_from_file.rs index 18ab769f..7d9b821d 100644 --- a/src/optimizer/rule/implementation/dml/copy_from_file.rs +++ b/src/optimizer/rule/implementation/dml/copy_from_file.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static COPY_FROM_FILE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dml/copy_to_file.rs b/src/optimizer/rule/implementation/dml/copy_to_file.rs index 06785524..0ef60ae5 100644 --- a/src/optimizer/rule/implementation/dml/copy_to_file.rs +++ b/src/optimizer/rule/implementation/dml/copy_to_file.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static COPY_TO_FILE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dml/delete.rs b/src/optimizer/rule/implementation/dml/delete.rs index 845f3e92..324b820a 100644 --- a/src/optimizer/rule/implementation/dml/delete.rs +++ b/src/optimizer/rule/implementation/dml/delete.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static DELETE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dml/insert.rs b/src/optimizer/rule/implementation/dml/insert.rs index b697450c..4b4fcf21 100644 --- a/src/optimizer/rule/implementation/dml/insert.rs +++ b/src/optimizer/rule/implementation/dml/insert.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static INSERT_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dml/mod.rs b/src/optimizer/rule/implementation/dml/mod.rs index 69e64a46..98b07283 100644 --- a/src/optimizer/rule/implementation/dml/mod.rs +++ b/src/optimizer/rule/implementation/dml/mod.rs @@ -13,7 +13,9 @@ // limitations under the License. pub(crate) mod analyze; +#[cfg(feature = "copy")] pub(crate) mod copy_from_file; +#[cfg(feature = "copy")] pub(crate) mod copy_to_file; pub(crate) mod delete; pub(crate) mod insert; diff --git a/src/optimizer/rule/implementation/dml/update.rs b/src/optimizer/rule/implementation/dml/update.rs index 55d21dd8..30e34607 100644 --- a/src/optimizer/rule/implementation/dml/update.rs +++ b/src/optimizer/rule/implementation/dml/update.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static UPDATE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/aggregate.rs b/src/optimizer/rule/implementation/dql/aggregate.rs index 10a7ec68..08dd7663 100644 --- a/src/optimizer/rule/implementation/dql/aggregate.rs +++ b/src/optimizer/rule/implementation/dql/aggregate.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static GROUP_BY_AGGREGATE_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/dummy.rs b/src/optimizer/rule/implementation/dql/dummy.rs index afcacaae..efb0b169 100644 --- a/src/optimizer/rule/implementation/dql/dummy.rs +++ b/src/optimizer/rule/implementation/dql/dummy.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static DUMMY_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/filter.rs b/src/optimizer/rule/implementation/dql/filter.rs index b720c1a1..8f5f284c 100644 --- a/src/optimizer/rule/implementation/dql/filter.rs +++ b/src/optimizer/rule/implementation/dql/filter.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static FILTER_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/function_scan.rs b/src/optimizer/rule/implementation/dql/function_scan.rs index ea982c82..8e23711d 100644 --- a/src/optimizer/rule/implementation/dql/function_scan.rs +++ b/src/optimizer/rule/implementation/dql/function_scan.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static FUNCTION_SCAN_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/join.rs b/src/optimizer/rule/implementation/dql/join.rs index de6cdbdf..e81be2fd 100644 --- a/src/optimizer/rule/implementation/dql/join.rs +++ b/src/optimizer/rule/implementation/dql/join.rs @@ -18,7 +18,6 @@ use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, Match use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::join::{JoinCondition, JoinOperator}; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; -use crate::storage::Transaction; use std::sync::LazyLock; static JOIN_PATTERN: LazyLock = LazyLock::new(|| Pattern { @@ -35,11 +34,12 @@ impl MatchPattern for JoinImplementation { } } -impl ImplementationRule for JoinImplementation { +impl ImplementationRule for JoinImplementation { fn update_best_option( &self, op: &Operator, - _: &StatisticMetaLoader<'_, T>, + _: &crate::planner::PlanArena, + _: &StatisticMetaLoader<'_>, best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { let mut physical_option = PhysicalOption::new(PlanImpl::NestLoopJoin, SortOption::None); diff --git a/src/optimizer/rule/implementation/dql/limit.rs b/src/optimizer/rule/implementation/dql/limit.rs index 47308921..ada011c8 100644 --- a/src/optimizer/rule/implementation/dql/limit.rs +++ b/src/optimizer/rule/implementation/dql/limit.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static LIMIT_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/mark_apply.rs b/src/optimizer/rule/implementation/dql/mark_apply.rs index 43f21691..014ca9d4 100644 --- a/src/optimizer/rule/implementation/dql/mark_apply.rs +++ b/src/optimizer/rule/implementation/dql/mark_apply.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static MARK_APPLY_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/projection.rs b/src/optimizer/rule/implementation/dql/projection.rs index 672d81a4..951ad627 100644 --- a/src/optimizer/rule/implementation/dql/projection.rs +++ b/src/optimizer/rule/implementation/dql/projection.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static PROJECTION_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/scalar_apply.rs b/src/optimizer/rule/implementation/dql/scalar_apply.rs index 8ac4f5b9..a77a2714 100644 --- a/src/optimizer/rule/implementation/dql/scalar_apply.rs +++ b/src/optimizer/rule/implementation/dql/scalar_apply.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static SCALAR_APPLY_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/scalar_subquery.rs b/src/optimizer/rule/implementation/dql/scalar_subquery.rs index 75a04732..64e85738 100644 --- a/src/optimizer/rule/implementation/dql/scalar_subquery.rs +++ b/src/optimizer/rule/implementation/dql/scalar_subquery.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static SCALAR_SUBQUERY_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/dql/sort.rs b/src/optimizer/rule/implementation/dql/sort.rs index 9b8028ed..468e1f54 100644 --- a/src/optimizer/rule/implementation/dql/sort.rs +++ b/src/optimizer/rule/implementation/dql/sort.rs @@ -17,7 +17,6 @@ use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; -use crate::storage::Transaction; use std::sync::LazyLock; static SORT_PATTERN: LazyLock = LazyLock::new(|| Pattern { @@ -34,11 +33,12 @@ impl MatchPattern for SortImplementation { } } -impl ImplementationRule for SortImplementation { +impl ImplementationRule for SortImplementation { fn update_best_option( &self, op: &Operator, - _: &StatisticMetaLoader<'_, T>, + _: &crate::planner::PlanArena, + _: &StatisticMetaLoader<'_>, best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { if let Operator::Sort(op) = op { diff --git a/src/optimizer/rule/implementation/dql/table_scan.rs b/src/optimizer/rule/implementation/dql/table_scan.rs index cdaf4baf..d2d725bf 100644 --- a/src/optimizer/rule/implementation/dql/table_scan.rs +++ b/src/optimizer/rule/implementation/dql/table_scan.rs @@ -17,7 +17,6 @@ use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; -use crate::storage::Transaction; use crate::types::index::{IndexLookup, IndexType}; use std::sync::LazyLock; @@ -35,19 +34,25 @@ impl MatchPattern for SeqScanImplementation { } } -impl ImplementationRule for SeqScanImplementation { +impl ImplementationRule for SeqScanImplementation { fn update_best_option( &self, op: &Operator, - loader: &StatisticMetaLoader, + arena: &crate::planner::PlanArena, + loader: &StatisticMetaLoader<'_>, best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { if let Operator::TableScan(scan_op) = op { let cost = scan_op .index_infos .iter() - .find(|index_info| matches!(index_info.meta.ty, IndexType::PrimaryKey { .. })) - .map(|index_info| loader.load(&scan_op.table_name, index_info.meta.id)) + .find(|index_info| { + matches!( + arena.index(index_info.meta).ty, + IndexType::PrimaryKey { .. } + ) + }) + .map(|index_info| loader.load(&scan_op.table_name, arena.index(index_info.meta).id)) .transpose()? .flatten() .map(|statistics_meta| statistics_meta.histogram().values_len()); @@ -72,11 +77,12 @@ impl MatchPattern for IndexScanImplementation { } } -impl ImplementationRule for IndexScanImplementation { +impl ImplementationRule for IndexScanImplementation { fn update_best_option( &self, op: &Operator, - loader: &StatisticMetaLoader<'_, T>, + arena: &crate::planner::PlanArena, + loader: &StatisticMetaLoader<'_>, best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { if let Operator::TableScan(scan_op) = op { @@ -86,11 +92,12 @@ impl ImplementationRule for IndexScanImplementation { }; let mut cost = None; + let index_meta = arena.index(index_info.meta); if let Some(mut row_count) = - loader.collect_count(&scan_op.table_name, index_info.meta.id, range)? + loader.collect_count(&scan_op.table_name, index_meta.id, range)? { if index_info.covered_deserializers.is_none() - && !matches!(index_info.meta.ty, IndexType::PrimaryKey { .. }) + && !matches!(index_meta.ty, IndexType::PrimaryKey { .. }) { // need to return table query(non-covering index) row_count *= 2; diff --git a/src/optimizer/rule/implementation/dql/top_k.rs b/src/optimizer/rule/implementation/dql/top_k.rs index 39c4b85b..85c5322e 100644 --- a/src/optimizer/rule/implementation/dql/top_k.rs +++ b/src/optimizer/rule/implementation/dql/top_k.rs @@ -17,7 +17,6 @@ use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; use crate::optimizer::core::rule::{BestPhysicalOption, ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; -use crate::storage::Transaction; use std::sync::LazyLock; static TOPK_PATTERN: LazyLock = LazyLock::new(|| Pattern { @@ -34,11 +33,12 @@ impl MatchPattern for TopKImplementation { } } -impl ImplementationRule for TopKImplementation { +impl ImplementationRule for TopKImplementation { fn update_best_option( &self, op: &Operator, - _: &StatisticMetaLoader<'_, T>, + _: &crate::planner::PlanArena, + _: &StatisticMetaLoader<'_>, best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { if let Operator::TopK(op) = op { diff --git a/src/optimizer/rule/implementation/dql/values.rs b/src/optimizer/rule/implementation/dql/values.rs index 69faffa8..ee4d530a 100644 --- a/src/optimizer/rule/implementation/dql/values.rs +++ b/src/optimizer/rule/implementation/dql/values.rs @@ -19,7 +19,6 @@ use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; use crate::planner::operator::{Operator, PhysicalOption, PlanImpl, SortOption}; use crate::single_mapping; -use crate::storage::Transaction; use std::sync::LazyLock; static VALUES_PATTERN: LazyLock = LazyLock::new(|| Pattern { diff --git a/src/optimizer/rule/implementation/macros.rs b/src/optimizer/rule/implementation/macros.rs index 14f86e86..15014862 100644 --- a/src/optimizer/rule/implementation/macros.rs +++ b/src/optimizer/rule/implementation/macros.rs @@ -21,11 +21,12 @@ macro_rules! single_mapping { } } - impl ImplementationRule for $ty { + impl ImplementationRule for $ty { fn update_best_option( &self, _: &Operator, - _: &StatisticMetaLoader<'_, T>, + _: &$crate::planner::PlanArena, + _: &StatisticMetaLoader<'_>, best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { //TODO: CostModel diff --git a/src/optimizer/rule/implementation/mod.rs b/src/optimizer/rule/implementation/mod.rs index 580f792e..696ee441 100644 --- a/src/optimizer/rule/implementation/mod.rs +++ b/src/optimizer/rule/implementation/mod.rs @@ -28,7 +28,9 @@ use crate::optimizer::rule::implementation::ddl::drop_column::DropColumnImplemen use crate::optimizer::rule::implementation::ddl::drop_table::DropTableImplementation; use crate::optimizer::rule::implementation::ddl::truncate::TruncateImplementation; use crate::optimizer::rule::implementation::dml::analyze::AnalyzeImplementation; +#[cfg(feature = "copy")] use crate::optimizer::rule::implementation::dml::copy_from_file::CopyFromFileImplementation; +#[cfg(feature = "copy")] use crate::optimizer::rule::implementation::dml::copy_to_file::CopyToFileImplementation; use crate::optimizer::rule::implementation::dml::delete::DeleteImplementation; use crate::optimizer::rule::implementation::dml::insert::InsertImplementation; @@ -52,7 +54,6 @@ use crate::optimizer::rule::implementation::dql::table_scan::{ use crate::optimizer::rule::implementation::dql::top_k::TopKImplementation; use crate::optimizer::rule::implementation::dql::values::ValuesImplementation; use crate::planner::operator::Operator; -use crate::storage::Transaction; #[repr(usize)] #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -72,7 +73,9 @@ pub enum ImplementationRuleRootTag { TopK, Values, Analyze, + #[cfg(feature = "copy")] CopyFromFile, + #[cfg(feature = "copy")] CopyToFile, Delete, Insert, @@ -105,7 +108,9 @@ impl ImplementationRuleRootTag { Operator::TopK(_) => Some(Self::TopK), Operator::Values(_) => Some(Self::Values), Operator::Analyze(_) => Some(Self::Analyze), + #[cfg(feature = "copy")] Operator::CopyFromFile(_) => Some(Self::CopyFromFile), + #[cfg(feature = "copy")] Operator::CopyToFile(_) => Some(Self::CopyToFile), Operator::Delete(_) => Some(Self::Delete), Operator::Insert(_) => Some(Self::Insert), @@ -151,7 +156,9 @@ pub enum ImplementationRuleImpl { Values, // DML Analyze, + #[cfg(feature = "copy")] CopyFromFile, + #[cfg(feature = "copy")] CopyToFile, Delete, Insert, @@ -184,7 +191,9 @@ impl MatchPattern for ImplementationRuleImpl { ImplementationRuleImpl::Sort => SortImplementation.pattern(), ImplementationRuleImpl::TopK => TopKImplementation.pattern(), ImplementationRuleImpl::Values => ValuesImplementation.pattern(), + #[cfg(feature = "copy")] ImplementationRuleImpl::CopyFromFile => CopyFromFileImplementation.pattern(), + #[cfg(feature = "copy")] ImplementationRuleImpl::CopyToFile => CopyToFileImplementation.pattern(), ImplementationRuleImpl::Delete => DeleteImplementation.pattern(), ImplementationRuleImpl::Insert => InsertImplementation.pattern(), @@ -222,7 +231,9 @@ impl ImplementationRuleImpl { ImplementationRuleImpl::TopK => ImplementationRuleRootTag::TopK, ImplementationRuleImpl::Values => ImplementationRuleRootTag::Values, ImplementationRuleImpl::Analyze => ImplementationRuleRootTag::Analyze, + #[cfg(feature = "copy")] ImplementationRuleImpl::CopyFromFile => ImplementationRuleRootTag::CopyFromFile, + #[cfg(feature = "copy")] ImplementationRuleImpl::CopyToFile => ImplementationRuleRootTag::CopyToFile, ImplementationRuleImpl::Delete => ImplementationRuleRootTag::Delete, ImplementationRuleImpl::Insert => ImplementationRuleRootTag::Insert, @@ -237,119 +248,173 @@ impl ImplementationRuleImpl { } } -impl ImplementationRule for ImplementationRuleImpl { +impl ImplementationRule for ImplementationRuleImpl { fn update_best_option( &self, operator: &Operator, - loader: &StatisticMetaLoader<'_, T>, + arena: &crate::planner::PlanArena, + loader: &StatisticMetaLoader<'_>, best_physical_option: &mut BestPhysicalOption, ) -> Result<(), DatabaseError> { match self { ImplementationRuleImpl::GroupByAggregate => GroupByAggregateImplementation - .update_best_option(operator, loader, best_physical_option)?, + .update_best_option(operator, arena, loader, best_physical_option)?, ImplementationRuleImpl::SimpleAggregate => SimpleAggregateImplementation - .update_best_option(operator, loader, best_physical_option)?, - ImplementationRuleImpl::Dummy => { - DummyImplementation.update_best_option(operator, loader, best_physical_option)? - } - ImplementationRuleImpl::Filter => { - FilterImplementation.update_best_option(operator, loader, best_physical_option)? - } - ImplementationRuleImpl::HashJoin => { - JoinImplementation.update_best_option(operator, loader, best_physical_option)? - } - ImplementationRuleImpl::Limit => { - LimitImplementation.update_best_option(operator, loader, best_physical_option)? - } + .update_best_option(operator, arena, loader, best_physical_option)?, + ImplementationRuleImpl::Dummy => DummyImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::Filter => FilterImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::HashJoin => JoinImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::Limit => LimitImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, ImplementationRuleImpl::MarkApply => MarkApplyImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, ImplementationRuleImpl::Projection => ProjectionImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, ImplementationRuleImpl::ScalarApply => ScalarApplyImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, ImplementationRuleImpl::ScalarSubquery => ScalarSubqueryImplementation - .update_best_option(operator, loader, best_physical_option)?, - ImplementationRuleImpl::SeqScan => { - SeqScanImplementation.update_best_option(operator, loader, best_physical_option)? - } + .update_best_option(operator, arena, loader, best_physical_option)?, + ImplementationRuleImpl::SeqScan => SeqScanImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, ImplementationRuleImpl::IndexScan => IndexScanImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, ImplementationRuleImpl::FunctionScan => FunctionScanImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, - ImplementationRuleImpl::Sort => { - SortImplementation.update_best_option(operator, loader, best_physical_option)? - } - ImplementationRuleImpl::TopK => { - TopKImplementation.update_best_option(operator, loader, best_physical_option)? - } - ImplementationRuleImpl::Values => { - ValuesImplementation.update_best_option(operator, loader, best_physical_option)? - } + ImplementationRuleImpl::Sort => SortImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::TopK => TopKImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::Values => ValuesImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + #[cfg(feature = "copy")] ImplementationRuleImpl::CopyFromFile => CopyFromFileImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, + #[cfg(feature = "copy")] ImplementationRuleImpl::CopyToFile => CopyToFileImplementation.update_best_option( operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::Delete => DeleteImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::Insert => InsertImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::Update => UpdateImplementation.update_best_option( + operator, + arena, loader, best_physical_option, )?, - ImplementationRuleImpl::Delete => { - DeleteImplementation.update_best_option(operator, loader, best_physical_option)? - } - ImplementationRuleImpl::Insert => { - InsertImplementation.update_best_option(operator, loader, best_physical_option)? - } - ImplementationRuleImpl::Update => { - UpdateImplementation.update_best_option(operator, loader, best_physical_option)? - } ImplementationRuleImpl::AddColumn => AddColumnImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, ImplementationRuleImpl::ChangeColumn => ChangeColumnImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, ImplementationRuleImpl::CreateTable => CreateTableImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, ImplementationRuleImpl::DropColumn => DropColumnImplementation.update_best_option( operator, + arena, loader, best_physical_option, )?, ImplementationRuleImpl::DropTable => DropTableImplementation.update_best_option( operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::Truncate => TruncateImplementation.update_best_option( + operator, + arena, + loader, + best_physical_option, + )?, + ImplementationRuleImpl::Analyze => AnalyzeImplementation.update_best_option( + operator, + arena, loader, best_physical_option, )?, - ImplementationRuleImpl::Truncate => { - TruncateImplementation.update_best_option(operator, loader, best_physical_option)? - } - ImplementationRuleImpl::Analyze => { - AnalyzeImplementation.update_best_option(operator, loader, best_physical_option)? - } } Ok(()) diff --git a/src/optimizer/rule/normalization/agg_elimination.rs b/src/optimizer/rule/normalization/agg_elimination.rs index a4a93435..3e839962 100644 --- a/src/optimizer/rule/normalization/agg_elimination.rs +++ b/src/optimizer/rule/normalization/agg_elimination.rs @@ -25,7 +25,11 @@ use crate::planner::{Childrens, LogicalPlan}; pub struct EliminateRedundantSort; impl NormalizationRule for EliminateRedundantSort { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { let (sort_fields, topk_limit) = match &plan.operator { Operator::Sort(sort_op) => (sort_op.sort_fields.clone(), None), Operator::TopK(topk_op) => ( @@ -39,8 +43,8 @@ impl NormalizationRule for EliminateRedundantSort { Some(child) => child, None => return Ok(false), }; - mark_sort_preserving_indexes(child, &sort_fields); - let can_remove = ensure_index_order(child, &sort_fields); + mark_sort_preserving_indexes(child, &sort_fields, arena); + let can_remove = ensure_index_order(child, &sort_fields, arena); if !can_remove { return Ok(false); @@ -59,8 +63,12 @@ impl NormalizationRule for EliminateRedundantSort { } } -fn mark_sort_preserving_indexes(plan: &mut LogicalPlan, required: &[SortField]) { - mark_order_hint(plan, required, OrderHintKind::SortElimination); +fn mark_sort_preserving_indexes( + plan: &mut LogicalPlan, + required: &[SortField], + arena: &crate::planner::PlanArena, +) { + mark_order_hint(plan, required, OrderHintKind::SortElimination, arena); } #[derive(Copy, Clone)] @@ -85,7 +93,12 @@ impl<'a> ScanOrderHint<'a> { } } -fn mark_order_hint(plan: &mut LogicalPlan, required: &[SortField], hint: OrderHintKind) { +fn mark_order_hint( + plan: &mut LogicalPlan, + required: &[SortField], + hint: OrderHintKind, + arena: &crate::planner::PlanArena, +) { if required.is_empty() { return; } @@ -97,11 +110,11 @@ fn mark_order_hint(plan: &mut LogicalPlan, required: &[SortField], hint: OrderHi | Operator::TopK(_) | Operator::Sort(_) => { if let Childrens::Only(child) = plan.childrens.as_mut() { - mark_order_hint(child, required, hint); + mark_order_hint(child, required, hint, arena); } } Operator::TableScan(scan_op) => { - apply_scan_order_hint(scan_op, ScanOrderHint::sort_fields(required), hint); + apply_scan_order_hint(scan_op, ScanOrderHint::sort_fields(required), hint, arena); } _ => {} } @@ -111,22 +124,23 @@ pub(crate) fn apply_scan_order_hint( scan_op: &mut TableScanOperator, required: ScanOrderHint<'_>, hint: OrderHintKind, + arena: &crate::planner::PlanArena, ) { let required_from_table = match required { ScanOrderHint::SortFields(fields) => fields.iter().all(|field| { - field.expr.all_referenced_columns(true, |column| { + field.expr.all_referenced_columns(arena, |arena, column| { scan_op .columns .iter() - .any(|table_column| table_column == column) + .any(|scan_column| arena.same_column(*scan_column, *column)) }) }), ScanOrderHint::DistinctGroupBy(groupby_exprs) => groupby_exprs.iter().all(|expr| { - expr.all_referenced_columns(true, |column| { + expr.all_referenced_columns(arena, |arena, column| { scan_op .columns .iter() - .any(|table_column| table_column == column) + .any(|scan_column| arena.same_column(*scan_column, *column)) }) }), }; @@ -134,7 +148,7 @@ pub(crate) fn apply_scan_order_hint( return; } for index_info in scan_op.index_infos.iter_mut() { - if hint_covers(required, &index_info.sort_option) { + if hint_covers(required, &index_info.sort_option, arena) { let covered = hint_len(required); match hint { OrderHintKind::SortElimination => { @@ -163,12 +177,18 @@ fn hint_len(required: ScanOrderHint<'_>) -> usize { } } -fn hint_covers(required: ScanOrderHint<'_>, provided: &SortOption) -> bool { +fn hint_covers( + required: ScanOrderHint<'_>, + provided: &SortOption, + arena: &crate::planner::PlanArena, +) -> bool { match required { - ScanOrderHint::SortFields(fields) => covers(fields, provided, sort_field_matches), + ScanOrderHint::SortFields(fields) => covers(fields, provided, |required, provided| { + sort_field_matches(required, provided, arena) + }), ScanOrderHint::DistinctGroupBy(groupby_exprs) => { covers(groupby_exprs, provided, |expr, field| { - field.asc && !field.nulls_first && expr.eq_ignore_colref_pos(&field.expr) + field.asc && !field.nulls_first && expr.eq_ignore_colref_pos(&field.expr, arena) }) } } @@ -185,7 +205,11 @@ pub(crate) fn distinct_sort_fields(groupby_exprs: &[ScalarExpression]) -> Vec Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { let Operator::Aggregate(op) = &plan.operator else { return Ok(false); }; @@ -207,7 +231,7 @@ impl NormalizationRule for UseStreamDistinct { Some(child) => child, None => return Ok(false), }; - if !ensure_stream_distinct_order(child, &required) { + if !ensure_stream_distinct_order(child, &required, arena) { return Ok(false); } @@ -219,27 +243,35 @@ impl NormalizationRule for UseStreamDistinct { } } -pub(crate) fn apply_annotated_post_rules(plan: &mut LogicalPlan) -> Result { +pub(crate) fn apply_annotated_post_rules( + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, +) -> Result { let mut changed = false; - if EliminateRedundantSort.apply(plan)? { - plan.reset_output_schema_cache_recursive(); + if EliminateRedundantSort.apply(plan, arena)? { changed = true; } - if UseStreamDistinct.apply(plan)? { + if UseStreamDistinct.apply(plan, arena)? { changed = true; } Ok(changed) } -fn ensure_stream_distinct_order(plan: &mut LogicalPlan, required: &[SortField]) -> bool { +fn ensure_stream_distinct_order( + plan: &mut LogicalPlan, + required: &[SortField], + arena: &crate::planner::PlanArena, +) -> bool { if let Some(PhysicalOption { plan: PlanImpl::IndexScan(index_info), .. }) = plan.physical_option.as_ref() { - if covers(required, &index_info.sort_option, sort_field_matches) { + if covers(required, &index_info.sort_option, |required, provided| { + sort_field_matches(required, provided, arena) + }) { return true; } } @@ -247,14 +279,18 @@ fn ensure_stream_distinct_order(plan: &mut LogicalPlan, required: &[SortField]) if let Some(physical_option) = plan.physical_option.as_ref() { match physical_option.sort_option() { SortOption::OrderBy { .. } - if covers(required, physical_option.sort_option(), sort_field_matches) => + if covers( + required, + physical_option.sort_option(), + |required, provided| sort_field_matches(required, provided, arena), + ) => { return true } SortOption::OrderBy { .. } => {} SortOption::Follow => { if let Childrens::Only(child) = plan.childrens.as_mut() { - if ensure_stream_distinct_order(child, required) { + if ensure_stream_distinct_order(child, required, arena) { return true; } } @@ -266,13 +302,19 @@ fn ensure_stream_distinct_order(plan: &mut LogicalPlan, required: &[SortField]) false } -fn ensure_index_order(plan: &mut LogicalPlan, required: &[SortField]) -> bool { +fn ensure_index_order( + plan: &mut LogicalPlan, + required: &[SortField], + arena: &crate::planner::PlanArena, +) -> bool { if let Some(PhysicalOption { plan: PlanImpl::IndexScan(index_info), .. }) = plan.physical_option.as_ref() { - if covers(required, &index_info.sort_option, sort_field_matches) { + if covers(required, &index_info.sort_option, |required, provided| { + sort_field_matches(required, provided, arena) + }) { return true; } } @@ -280,7 +322,7 @@ fn ensure_index_order(plan: &mut LogicalPlan, required: &[SortField]) -> bool { if let Some(physical_option) = plan.physical_option.as_ref() { if matches!(physical_option.sort_option(), SortOption::Follow) { if let Childrens::Only(child) = plan.childrens.as_mut() { - if ensure_index_order(child, required) { + if ensure_index_order(child, required, arena) { return true; } } @@ -290,10 +332,14 @@ fn ensure_index_order(plan: &mut LogicalPlan, required: &[SortField]) -> bool { false } -fn sort_field_matches(required: &SortField, provided: &SortField) -> bool { +fn sort_field_matches( + required: &SortField, + provided: &SortField, + arena: &crate::planner::PlanArena, +) -> bool { required.asc == provided.asc && required.nulls_first == provided.nulls_first - && required.expr.eq_ignore_colref_pos(&provided.expr) + && required.expr.eq_ignore_colref_pos(&provided.expr, arena) } pub(crate) fn covers( @@ -336,7 +382,7 @@ pub(crate) fn covers( #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use super::{EliminateRedundantSort, UseStreamDistinct}; - use crate::catalog::{ColumnCatalog, ColumnRef, TableName}; + use crate::catalog::{ColumnCatalog, TableName}; use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::expression::ScalarExpression; @@ -352,24 +398,28 @@ mod tests { use crate::types::value::DataValue; use crate::types::LogicalType; use std::ops::Bound; - use std::sync::Arc; use ulid::Ulid; - - fn make_sort_field(name: &str) -> SortField { - make_sort_field_with_position(name, 0) + fn make_sort_field(arena: &mut crate::planner::PlanArena, name: &str) -> SortField { + make_sort_field_with_position(arena, name, 0) } - fn make_sort_field_with_position(name: &str, position: usize) -> SortField { - let column = ColumnRef::from(ColumnCatalog::new_dummy(name.to_string())); + fn make_sort_field_with_position( + arena: &mut crate::planner::PlanArena, + name: &str, + position: usize, + ) -> SortField { + let column = arena.alloc_column(ColumnCatalog::new_dummy(name.to_string())); SortField::new(ScalarExpression::column_expr(column, position), true, false) } fn build_plan( + arena: &mut crate::planner::PlanArena, required_fields: Vec, index_fields: Vec, ignore_prefix_len: usize, ) -> LogicalPlan { - let (index_info, index_sort_option) = build_index_info(index_fields, ignore_prefix_len); + let (index_info, index_sort_option) = + build_index_info(arena, index_fields, ignore_prefix_len); let mut leaf = LogicalPlan::new(Operator::Dummy, Childrens::None); leaf.physical_option = Some(PhysicalOption::new( @@ -397,6 +447,7 @@ mod tests { } fn build_index_info( + arena: &mut crate::planner::PlanArena, index_fields: Vec, ignore_prefix_len: usize, ) -> (IndexInfo, SortOption) { @@ -406,7 +457,7 @@ mod tests { ignore_prefix_len, }; let table_name: TableName = ::std::sync::Arc::from("t1"); - let meta = Arc::new(IndexMeta { + let meta = arena.alloc_index(IndexMeta { id: 1, column_ids: (0..len).map(|_| Ulid::new()).collect(), table_name, @@ -429,14 +480,16 @@ mod tests { (index_info, sort_option) } - fn build_distinct_scan_plan() -> (LogicalPlan, SortOption) { + fn build_distinct_scan_plan( + arena: &mut crate::planner::PlanArena, + ) -> (LogicalPlan, SortOption) { let table_name: TableName = ::std::sync::Arc::from("t1"); - let c1 = ColumnRef::from(ColumnCatalog::new_dummy("c1".to_string())); + let c1 = arena.alloc_column(ColumnCatalog::new_dummy("c1".to_string())); let c1_id = Ulid::new(); - let columns = vec![c1.clone()]; + let columns = vec![c1]; let sort_fields = vec![SortField::new( - ScalarExpression::column_expr(c1.clone(), 0), + ScalarExpression::column_expr(c1, 0), true, false, )]; @@ -445,7 +498,7 @@ mod tests { ignore_prefix_len: 0, }; let index_info = IndexInfo { - meta: Arc::new(IndexMeta { + meta: arena.alloc_index(IndexMeta { id: 1, column_ids: vec![c1_id], table_name: table_name.clone(), @@ -487,19 +540,28 @@ mod tests { #[test] fn remove_sort_when_index_matches_order() -> Result<(), DatabaseError> { - let sort_field = make_sort_field("c1"); - let mut plan = build_plan(vec![sort_field.clone()], vec![sort_field], 0); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let sort_field = make_sort_field(&mut arena, "c1"); + let mut plan = build_plan(&mut arena, vec![sort_field.clone()], vec![sort_field], 0); let rule = EliminateRedundantSort; - assert!(rule.apply(&mut plan)?); + assert!(rule.apply(&mut plan, &mut arena)?); assert!(matches!(plan.operator, Operator::Filter(_))); Ok(()) } #[test] fn remove_topk_when_index_matches_order() -> Result<(), DatabaseError> { - let sort_field = make_sort_field("c1"); - let mut plan = build_plan(vec![sort_field.clone()], vec![sort_field.clone()], 0); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let sort_field = make_sort_field(&mut arena, "c1"); + let mut plan = build_plan( + &mut arena, + vec![sort_field.clone()], + vec![sort_field.clone()], + 0, + ); plan.operator = Operator::TopK(TopKOperator { sort_fields: vec![sort_field], limit: 10, @@ -507,7 +569,7 @@ mod tests { }); let rule = EliminateRedundantSort; - assert!(rule.apply(&mut plan)?); + assert!(rule.apply(&mut plan, &mut arena)?); match plan.operator { Operator::Limit(limit_op) => { assert_eq!(limit_op.limit, Some(10)); @@ -520,25 +582,30 @@ mod tests { #[test] fn remove_sort_when_prefix_can_be_ignored() -> Result<(), DatabaseError> { - let c1 = make_sort_field("c1"); - let c2 = make_sort_field("c2"); - let mut plan = build_plan(vec![c2.clone()], vec![c1, c2.clone()], 1); - super::mark_sort_preserving_indexes(&mut plan, &[c2]); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let c1 = make_sort_field(&mut arena, "c1"); + let c2 = make_sort_field(&mut arena, "c2"); + let mut plan = build_plan(&mut arena, vec![c2.clone()], vec![c1, c2.clone()], 1); + super::mark_sort_preserving_indexes(&mut plan, &[c2], &arena); let rule = EliminateRedundantSort; - assert!(rule.apply(&mut plan)?); + assert!(rule.apply(&mut plan, &mut arena)?); Ok(()) } #[test] fn remove_topk_when_index_matches_same_column_with_different_positions( ) -> Result<(), DatabaseError> { - let required = make_sort_field_with_position("no_o_id", 0); - let provided_prefix_1 = make_sort_field_with_position("no_w_id", 0); - let provided_prefix_2 = make_sort_field_with_position("no_d_id", 1); - let provided_target = make_sort_field_with_position("no_o_id", 2); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let required = make_sort_field_with_position(&mut arena, "no_o_id", 0); + let provided_prefix_1 = make_sort_field_with_position(&mut arena, "no_w_id", 0); + let provided_prefix_2 = make_sort_field_with_position(&mut arena, "no_d_id", 1); + let provided_target = make_sort_field_with_position(&mut arena, "no_o_id", 2); let mut plan = build_plan( + &mut arena, vec![required.clone()], vec![provided_prefix_1, provided_prefix_2, provided_target], 2, @@ -550,20 +617,18 @@ mod tests { }); let rule = EliminateRedundantSort; - assert!(rule.apply(&mut plan)?); + assert!(rule.apply(&mut plan, &mut arena)?); assert!(matches!(plan.operator, Operator::Limit(_))); Ok(()) } #[test] fn annotate_sets_sort_hint_on_table_scan() -> Result<(), DatabaseError> { - let column = ColumnRef::from(ColumnCatalog::new_dummy("c1".to_string())); - let sort_field = SortField::new( - ScalarExpression::column_expr(column.clone(), 0), - true, - false, - ); - let (index_info, _) = build_index_info(vec![sort_field.clone()], 0); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let column = arena.alloc_column(ColumnCatalog::new_dummy("c1".to_string())); + let sort_field = SortField::new(ScalarExpression::column_expr(column, 0), true, false); + let (index_info, _) = build_index_info(&mut arena, vec![sort_field.clone()], 0); let columns = vec![column]; let table_name: TableName = ::std::sync::Arc::from("t"); @@ -590,7 +655,7 @@ mod tests { Operator::Sort(sort_op) => sort_op.sort_fields.clone(), _ => unreachable!("expected sort operator"), }; - super::mark_sort_preserving_indexes(&mut plan, &sort_fields); + super::mark_sort_preserving_indexes(&mut plan, &sort_fields, &arena); let table_plan = plan.childrens.pop_only(); match table_plan.operator { @@ -608,13 +673,20 @@ mod tests { #[test] fn annotate_sets_stream_distinct_hint_on_table_scan() -> Result<(), DatabaseError> { - let (mut plan, _) = build_distinct_scan_plan(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let (mut plan, _) = build_distinct_scan_plan(&mut arena); let required = match &plan.operator { Operator::Aggregate(op) => super::distinct_sort_fields(&op.groupby_exprs), _ => unreachable!("expected aggregate operator"), }; if let Childrens::Only(child) = plan.childrens.as_mut() { - super::mark_order_hint(child, &required, super::OrderHintKind::StreamDistinct); + super::mark_order_hint( + child, + &required, + super::OrderHintKind::StreamDistinct, + &arena, + ); } let child = plan.childrens.pop_only(); @@ -629,7 +701,9 @@ mod tests { #[test] fn use_stream_distinct_when_order_satisfied() -> Result<(), DatabaseError> { - let (mut plan, sort_option) = build_distinct_scan_plan(); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let (mut plan, sort_option) = build_distinct_scan_plan(&mut arena); if let Childrens::Only(child) = plan.childrens.as_mut() { if let Operator::TableScan(scan_op) = &child.operator { let index_info = scan_op.index_infos[0].clone(); @@ -645,7 +719,7 @@ mod tests { )); let rule = UseStreamDistinct; - assert!(rule.apply(&mut plan)?); + assert!(rule.apply(&mut plan, &mut arena)?); assert!(matches!( plan.physical_option, Some(PhysicalOption { @@ -658,26 +732,31 @@ mod tests { #[test] fn keep_sort_when_order_not_covered() -> Result<(), DatabaseError> { - let c1 = make_sort_field("c1"); - let c2 = make_sort_field("c2"); - let mut plan = build_plan(vec![c2.clone()], vec![c1.clone(), c2.clone()], 0); - super::mark_sort_preserving_indexes(&mut plan, &[c2]); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let c1 = make_sort_field(&mut arena, "c1"); + let c2 = make_sort_field(&mut arena, "c2"); + let mut plan = build_plan( + &mut arena, + vec![c2.clone()], + vec![c1.clone(), c2.clone()], + 0, + ); + super::mark_sort_preserving_indexes(&mut plan, &[c2], &arena); let rule = EliminateRedundantSort; - assert!(!rule.apply(&mut plan)?); + assert!(!rule.apply(&mut plan, &mut arena)?); assert!(matches!(plan.operator, Operator::Sort(_))); Ok(()) } #[test] fn promote_index_to_remove_sort() -> Result<(), DatabaseError> { - let column = ColumnRef::from(ColumnCatalog::new_dummy("c_first".to_string())); - let sort_field = SortField::new( - ScalarExpression::column_expr(column.clone(), 0), - true, - false, - ); - let (mut index_info, _) = build_index_info(vec![sort_field.clone()], 0); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = crate::planner::PlanArena::new(&table_arena); + let column = arena.alloc_column(ColumnCatalog::new_dummy("c_first".to_string())); + let sort_field = SortField::new(ScalarExpression::column_expr(column, 0), true, false); + let (mut index_info, _) = build_index_info(&mut arena, vec![sort_field.clone()], 0); index_info.lookup = Some(IndexLookup::Static(Range::Scope { min: Bound::Unbounded, max: Bound::Unbounded, @@ -725,9 +804,9 @@ mod tests { Operator::Sort(sort_op) => sort_op.sort_fields.clone(), _ => unreachable!("expected sort operator"), }; - super::mark_sort_preserving_indexes(&mut plan, &sort_fields); + super::mark_sort_preserving_indexes(&mut plan, &sort_fields, &arena); let rule = EliminateRedundantSort; - assert!(rule.apply(&mut plan)?); + assert!(rule.apply(&mut plan, &mut arena)?); assert!(matches!(plan.operator, Operator::Filter(_))); let table_plan = plan.childrens.pop_only(); diff --git a/src/optimizer/rule/normalization/column_pruning.rs b/src/optimizer/rule/normalization/column_pruning.rs index e778197a..1a749c9c 100644 --- a/src/optimizer/rule/normalization/column_pruning.rs +++ b/src/optimizer/rule/normalization/column_pruning.rs @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::ColumnSummary; +use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::agg::AggKind; use crate::expression::visitor::Visitor; -use crate::expression::{HasCountStar, ScalarExpression}; +use crate::expression::{AliasType, HasCountStar, ScalarExpression}; use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::rule::normalization::{remap_expr_positions, remap_exprs_positions}; use crate::planner::operator::join::JoinCondition; @@ -25,7 +25,6 @@ use crate::planner::{Childrens, LogicalPlan}; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; use crate::types::LogicalType; -use std::collections::HashSet; #[derive(Clone)] pub struct ColumnPruning; @@ -35,29 +34,76 @@ struct ApplyOutcome { removed_positions: Vec, } +#[derive(Clone, Default)] +struct ReferencedColumns { + columns: Vec, +} + +impl ReferencedColumns { + fn with_arena_capacity(arena: &crate::planner::PlanArena) -> Self { + Self { + columns: Vec::with_capacity(arena.allocated_columns_len()), + } + } + + fn clear(&mut self) { + self.columns.clear(); + } + + fn insert(&mut self, column: ColumnRef, arena: &crate::planner::PlanArena) { + if let Err(index) = self.search(column, arena) { + self.columns.insert(index, column); + } + } + + fn extend( + &mut self, + columns: impl IntoIterator, + arena: &crate::planner::PlanArena, + ) { + let columns = columns.into_iter(); + self.columns.reserve(columns.size_hint().0); + for column in columns { + self.insert(column, arena); + } + } + + fn contains(&self, column: ColumnRef, arena: &crate::planner::PlanArena) -> bool { + self.search(column, arena).is_ok() + } + + fn search(&self, column: ColumnRef, arena: &crate::planner::PlanArena) -> Result { + let summary = arena.column(column).summary(); + self.columns + .binary_search_by(|candidate| arena.column(*candidate).summary().cmp(summary)) + } +} + impl ApplyOutcome { - fn new() -> Self { + fn with_arena_capacity(arena: &crate::planner::PlanArena) -> Self { Self { changed: false, - removed_positions: Vec::new(), + removed_positions: Vec::with_capacity(arena.allocated_columns_len()), } } } impl ColumnPruning { - fn extend_operator_referenced_columns<'a>( - operator: &'a Operator, - referenced_columns: &mut HashSet<&'a ColumnSummary>, - ) { + fn extend_operator_referenced_columns( + operator: &Operator, + referenced_columns: &mut ReferencedColumns, + arena: &mut crate::planner::PlanArena, + ) -> Result<(), DatabaseError> { match operator { Operator::Aggregate(op) => { Self::extend_expr_referenced_columns( op.agg_calls.iter().chain(op.groupby_exprs.iter()), referenced_columns, - ); + arena, + )?; } Operator::Filter(op) => { - Self::extend_expr_referenced_columns([&op.predicate], referenced_columns); + Self::extend_expr_referenced_columns([&op.predicate], referenced_columns, arena)?; } Operator::Join(op) => { if let JoinCondition::On { on, filter } = &op.on { @@ -65,50 +111,63 @@ impl ColumnPruning { Self::extend_expr_referenced_columns( [left_expr, right_expr], referenced_columns, - ); + arena, + )?; } if let Some(filter_expr) = filter { - Self::extend_expr_referenced_columns([filter_expr], referenced_columns); + Self::extend_expr_referenced_columns( + [filter_expr], + referenced_columns, + arena, + )?; } } } Operator::Project(op) => { - Self::extend_expr_referenced_columns(op.exprs.iter(), referenced_columns); + Self::extend_expr_referenced_columns(op.exprs.iter(), referenced_columns, arena)?; } Operator::MarkApply(op) => { - Self::extend_expr_referenced_columns(op.predicates().iter(), referenced_columns); - referenced_columns.insert(op.output_column().summary()); + Self::extend_expr_referenced_columns( + op.predicates().iter(), + referenced_columns, + arena, + )?; + referenced_columns.insert(*op.output_column(), arena); } Operator::TableScan(op) => { - referenced_columns.extend(op.columns.iter().map(|column| column.summary())); + referenced_columns.extend(op.columns.iter().copied(), arena); } Operator::FunctionScan(op) => { Self::extend_expr_referenced_columns( op.table_function.args.iter(), referenced_columns, - ); + arena, + )?; } Operator::Sort(op) => { Self::extend_expr_referenced_columns( op.sort_fields.iter().map(|field| &field.expr), referenced_columns, - ); + arena, + )?; } Operator::TopK(op) => { Self::extend_expr_referenced_columns( op.sort_fields.iter().map(|field| &field.expr), referenced_columns, - ); + arena, + )?; } Operator::Values(op) => { - referenced_columns.extend(op.schema_ref.iter().map(|column| column.summary())); + referenced_columns.extend(op.schema_ref.iter().copied(), arena); } Operator::Union(op) => { referenced_columns.extend( op.left_schema_ref .iter() .chain(op._right_schema_ref.iter()) - .map(|column| column.summary()), + .copied(), + arena, ); } Operator::SetMembership(op) => { @@ -116,11 +175,12 @@ impl ColumnPruning { op.left_schema_ref .iter() .chain(op._right_schema_ref.iter()) - .map(|column| column.summary()), + .copied(), + arena, ); } Operator::Delete(op) => { - referenced_columns.extend(op.primary_keys.iter().map(|column| column.summary())); + referenced_columns.extend(op.primary_keys.iter().copied(), arena); } Operator::Dummy | Operator::Limit(_) @@ -142,53 +202,72 @@ impl ColumnPruning { | Operator::DropTable(_) | Operator::DropView(_) | Operator::DropIndex(_) - | Operator::Truncate(_) - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) => {} + | Operator::Truncate(_) => {} + #[cfg(feature = "copy")] + Operator::CopyFromFile(_) | Operator::CopyToFile(_) => {} } + Ok(()) } fn extend_expr_referenced_columns<'a>( exprs: impl IntoIterator, - referenced_columns: &mut HashSet<&'a ColumnSummary>, - ) { - struct ColumnSummaryCollector<'a, 'b> { - referenced_columns: &'b mut HashSet<&'a ColumnSummary>, + referenced_columns: &mut ReferencedColumns, + arena: &mut crate::planner::PlanArena, + ) -> Result<(), DatabaseError> { + struct ReferencedColumnCollector<'a, 'p> { + referenced_columns: &'a mut ReferencedColumns, + arena: &'a crate::planner::PlanArena<'p>, } - impl<'a> Visitor<'a> for ColumnSummaryCollector<'a, '_> { + impl Visitor<'_> for ReferencedColumnCollector<'_, '_> { fn visit_column_ref( &mut self, - column: &'a crate::catalog::ColumnRef, + column: &crate::catalog::ColumnRef, ) -> Result<(), DatabaseError> { - self.referenced_columns.insert(column.summary()); + self.referenced_columns.insert(*column, self.arena); Ok(()) } + + fn visit_alias( + &mut self, + expr: &ScalarExpression, + _ty: &AliasType, + ) -> Result<(), DatabaseError> { + self.visit(expr) + } } - let mut collector = ColumnSummaryCollector { referenced_columns }; + let mut collector = ReferencedColumnCollector { + referenced_columns, + arena, + }; for expr in exprs { - collector.visit(expr).unwrap(); + collector.visit(expr)?; } + Ok(()) } fn output_column_is_required( expr: &ScalarExpression, - column_references: &HashSet<&ColumnSummary>, + column_references: &ReferencedColumns, + arena: &mut crate::planner::PlanArena, ) -> bool { - column_references.contains(expr.output_column().summary()) + let output_column = expr.output_column_ref(arena); + column_references.contains(output_column, arena) } fn clear_exprs( - column_references: &HashSet<&ColumnSummary>, + column_references: &ReferencedColumns, exprs: &mut Vec, removed_positions: &mut Vec, output_start: usize, + arena: &mut crate::planner::PlanArena, ) { removed_positions.truncate(output_start); + removed_positions.reserve(exprs.len()); let mut position = 0; exprs.retain(|expr| { - let keep = Self::output_column_is_required(expr, column_references); + let keep = Self::output_column_is_required(expr, column_references, arena); if !keep { removed_positions.push(position); } @@ -264,9 +343,9 @@ impl ColumnPruning { | Operator::DropTable(_) | Operator::DropView(_) | Operator::DropIndex(_) - | Operator::Truncate(_) - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) => {} + | Operator::Truncate(_) => {} + #[cfg(feature = "copy")] + Operator::CopyFromFile(_) | Operator::CopyToFile(_) => {} } Ok(()) @@ -283,28 +362,36 @@ impl ColumnPruning { } fn apply_only_child( - referenced_columns: HashSet<&ColumnSummary>, + referenced_columns: ReferencedColumns, all_referenced: bool, childrens: &mut Childrens, outcome: &mut ApplyOutcome, output_start: usize, + arena: &mut crate::planner::PlanArena, ) -> Result { let Childrens::Only(child) = childrens else { outcome.changed = false; outcome.removed_positions.truncate(output_start); return Ok(false); }; - Self::_apply_appending(referenced_columns, all_referenced, child.as_mut(), outcome)?; + Self::_apply_appending( + referenced_columns, + all_referenced, + child.as_mut(), + outcome, + arena, + )?; Ok(outcome.changed) } #[allow(clippy::needless_lifetimes)] fn apply_twins( - referenced_columns: HashSet<&ColumnSummary>, + referenced_columns: ReferencedColumns, all_referenced: bool, childrens: &mut Childrens, outcome: &mut ApplyOutcome, output_start: usize, + arena: &mut crate::planner::PlanArena, ) -> Result { let Childrens::Twins { left, right } = childrens else { outcome.changed = false; @@ -317,11 +404,18 @@ impl ColumnPruning { all_referenced, left.as_mut(), outcome, + arena, )?; let left_changed = outcome.changed; outcome.removed_positions.truncate(output_start); - Self::_apply_appending(referenced_columns, all_referenced, right.as_mut(), outcome)?; + Self::_apply_appending( + referenced_columns, + all_referenced, + right.as_mut(), + outcome, + arena, + )?; let right_changed = outcome.changed; outcome.removed_positions.truncate(output_start); @@ -336,20 +430,22 @@ impl ColumnPruning { } fn _apply( - required_columns: HashSet<&ColumnSummary>, + required_columns: ReferencedColumns, all_referenced: bool, plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, ) -> Result { - let mut outcome = ApplyOutcome::new(); - Self::_apply_appending(required_columns, all_referenced, plan, &mut outcome)?; + let mut outcome = ApplyOutcome::with_arena_capacity(arena); + Self::_apply_appending(required_columns, all_referenced, plan, &mut outcome, arena)?; Ok(outcome) } fn _apply_appending( - required_columns: HashSet<&ColumnSummary>, + mut required_columns: ReferencedColumns, all_referenced: bool, plan: &mut LogicalPlan, outcome: &mut ApplyOutcome, + arena: &mut crate::planner::PlanArena, ) -> Result<(), DatabaseError> { let mut changed = false; let output_start = outcome.removed_positions.len(); @@ -363,6 +459,7 @@ impl ColumnPruning { &mut op.agg_calls, &mut outcome.removed_positions, output_start, + arena, ); if outcome.removed_positions.len() > output_start { changed = true; @@ -390,17 +487,23 @@ impl ColumnPruning { let child_start = outcome.removed_positions.len(); let child_changed = { - let mut child_required = if op.is_distinct { - required_columns - } else { - HashSet::new() - }; + if !op.is_distinct { + required_columns.clear(); + } Self::extend_expr_referenced_columns( op.agg_calls.iter().chain(op.groupby_exprs.iter()), - &mut child_required, - ); + &mut required_columns, + arena, + )?; - Self::apply_only_child(child_required, false, childrens, outcome, child_start)? + Self::apply_only_child( + required_columns, + false, + childrens, + outcome, + child_start, + arena, + )? }; if child_changed { Self::remap_operator_after_child_change( @@ -423,6 +526,7 @@ impl ColumnPruning { &mut op.exprs, &mut outcome.removed_positions, output_start, + arena, ); if outcome.removed_positions.len() > output_start { changed = true; @@ -433,15 +537,20 @@ impl ColumnPruning { let child_start = outcome.removed_positions.len(); let child_changed = { - let mut child_required = HashSet::new(); - Self::extend_expr_referenced_columns(op.exprs.iter(), &mut child_required); + required_columns.clear(); + Self::extend_expr_referenced_columns( + op.exprs.iter(), + &mut required_columns, + arena, + )?; Self::apply_only_child( - child_required, + required_columns, false, childrens, outcome, child_start, + arena, )? }; if child_changed { @@ -459,11 +568,12 @@ impl ColumnPruning { Operator::TableScan(op) => { if !all_referenced { outcome.removed_positions.truncate(output_start); + outcome.removed_positions.reserve(op.columns.len()); let mut position = 0; op.columns.retain(|column| { let current_position = position; position += 1; - let keep = required_columns.contains(column.summary()); + let keep = required_columns.contains(*column, arena); if !keep { outcome.removed_positions.push(current_position); } @@ -486,15 +596,28 @@ impl ColumnPruning { | Operator::Union(_) | Operator::SetMembership(_) | Operator::TopK(_) => { - if matches!( - operator, - Operator::ScalarApply(_) | Operator::MarkApply(_) | Operator::Join(_) - ) { + if matches!(operator, Operator::ScalarApply(_) | Operator::MarkApply(_)) { + let mut child_required = required_columns; + Self::extend_operator_referenced_columns(operator, &mut child_required, arena)?; + changed |= Self::apply_twins( + child_required, + true, + childrens, + outcome, + output_start, + arena, + )?; + outcome.removed_positions.truncate(output_start); + } else if matches!(operator, Operator::Join(_)) { let (old_left_outputs_len, left_removed_start, right_removed_start) = { let mut child_required = required_columns; - Self::extend_operator_referenced_columns(operator, &mut child_required); + Self::extend_operator_referenced_columns( + operator, + &mut child_required, + arena, + )?; let old_left_outputs_len = match childrens { - Childrens::Twins { left, .. } => left.output_schema().len(), + Childrens::Twins { left, .. } => left.output_schema(arena).len(), _ => 0, }; let Childrens::Twins { left, right } = childrens else { @@ -509,6 +632,7 @@ impl ColumnPruning { all_referenced, left.as_mut(), outcome, + arena, )?; let left_changed = outcome.changed; let right_removed_start = outcome.removed_positions.len(); @@ -517,6 +641,7 @@ impl ColumnPruning { all_referenced, right.as_mut(), outcome, + arena, )?; changed = left_changed || outcome.changed; ( @@ -593,26 +718,32 @@ impl ColumnPruning { } } else if matches!(operator, Operator::Union(_) | Operator::SetMembership(_)) { let mut child_required = required_columns; - Self::extend_operator_referenced_columns(operator, &mut child_required); + Self::extend_operator_referenced_columns(operator, &mut child_required, arena)?; changed |= Self::apply_twins( child_required, all_referenced, childrens, outcome, output_start, + arena, )?; outcome.removed_positions.truncate(output_start); } else { let child_start = outcome.removed_positions.len(); let child_changed = { let mut child_required = required_columns; - Self::extend_operator_referenced_columns(operator, &mut child_required); + Self::extend_operator_referenced_columns( + operator, + &mut child_required, + arena, + )?; Self::apply_only_child( child_required, all_referenced, childrens, outcome, child_start, + arena, )? }; if child_changed { @@ -636,6 +767,7 @@ impl ColumnPruning { childrens, outcome, child_start, + arena, )?; if child_changed { Self::remap_operator_after_child_change( @@ -653,10 +785,21 @@ impl ColumnPruning { | Operator::Analyze(_) => { let child_start = outcome.removed_positions.len(); let child_changed = { - let mut child_required = HashSet::new(); - Self::extend_operator_referenced_columns(operator, &mut child_required); + required_columns.clear(); + Self::extend_operator_referenced_columns( + operator, + &mut required_columns, + arena, + )?; - Self::apply_only_child(child_required, true, childrens, outcome, child_start)? + Self::apply_only_child( + required_columns, + true, + childrens, + outcome, + child_start, + arena, + )? }; if child_changed { Self::remap_operator_after_child_change( @@ -677,14 +820,16 @@ impl ColumnPruning { | Operator::Truncate(_) | Operator::ShowTable | Operator::ShowView - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) | Operator::AddColumn(_) | Operator::ChangeColumn(_) | Operator::DropColumn(_) | Operator::Describe(_) => { outcome.removed_positions.truncate(output_start); } + #[cfg(feature = "copy")] + Operator::CopyFromFile(_) | Operator::CopyToFile(_) => { + outcome.removed_positions.truncate(output_start); + } } outcome.changed = changed; @@ -693,8 +838,17 @@ impl ColumnPruning { } impl NormalizationRule for ColumnPruning { - fn apply(&self, plan: &mut LogicalPlan) -> Result { - let outcome = Self::_apply(HashSet::<&ColumnSummary>::new(), true, plan)?; + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { + let outcome = Self::_apply( + ReferencedColumns::with_arena_capacity(arena), + true, + plan, + arena, + )?; Ok(outcome.changed) } } @@ -708,12 +862,14 @@ mod tests { use crate::optimizer::rule::normalization::NormalizationRuleImpl; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; - use crate::planner::{Childrens, LogicalPlan}; - use crate::storage::rocksdb::RocksTransaction; + use crate::planner::{Childrens, LogicalPlan, PlanArena}; - fn optimize_column_pruning(sql: &str) -> Result { - let table_state = build_t1_table()?; - let plan = table_state.plan(sql)?; + fn optimize_column_pruning( + table_state: &crate::binder::test::TableState, + arena: &mut PlanArena, + sql: &str, + ) -> Result { + let plan = table_state.plan_with_arena(sql, arena)?; HepOptimizerPipeline::builder() .before_batch( @@ -723,7 +879,7 @@ mod tests { ) .build() .instantiate(plan) - .find_best::(None) + .find_best(None, arena) } fn contains_operator(plan: &LogicalPlan, predicate: impl Fn(&Operator) -> bool + Copy) -> bool { @@ -734,26 +890,36 @@ mod tests { .any(|child| contains_operator(child, predicate)) } - fn collect_scan_columns(plan: &LogicalPlan, table_name: &str, scans: &mut Vec>) { + fn collect_scan_columns( + plan: &LogicalPlan, + table_name: &str, + arena: &PlanArena, + scans: &mut Vec>, + ) { if let Operator::TableScan(op) = &plan.operator { if op.table_name.to_string() == table_name { scans.push( op.columns .iter() - .map(|column| column.name().to_string()) + .map(|column| arena.column(*column).name().to_string()) .collect(), ); } } for child in plan.childrens.iter() { - collect_scan_columns(child, table_name, scans); + collect_scan_columns(child, table_name, arena, scans); } } - fn assert_single_scan_columns(plan: &LogicalPlan, table_name: &str, expected: &[&str]) { + fn assert_single_scan_columns( + plan: &LogicalPlan, + table_name: &str, + arena: &PlanArena, + expected: &[&str], + ) { let mut scans = Vec::new(); - collect_scan_columns(plan, table_name, &mut scans); + collect_scan_columns(plan, table_name, arena, &mut scans); assert_eq!( scans.len(), 1, @@ -768,65 +934,84 @@ mod tests { #[test] fn test_column_pruning_project_single_side() -> Result<(), DatabaseError> { - let best_plan = optimize_column_pruning("select c1 from t1")?; + let table_state = build_t1_table()?; + let arena = PlanArena::new(&table_state.table_arena); + let mut arena = arena; + let best_plan = optimize_column_pruning(&table_state, &mut arena, "select c1 from t1")?; assert!(contains_operator(&best_plan, |op| matches!( op, Operator::Project(_) ))); - assert_single_scan_columns(&best_plan, "t1", &["c1"]); + assert_single_scan_columns(&best_plan, "t1", &arena, &["c1"]); Ok(()) } #[test] fn test_column_pruning_filter_single_side() -> Result<(), DatabaseError> { - let best_plan = optimize_column_pruning("select c1 from t1 where c2 > 1")?; + let table_state = build_t1_table()?; + let arena = PlanArena::new(&table_state.table_arena); + let mut arena = arena; + let best_plan = + optimize_column_pruning(&table_state, &mut arena, "select c1 from t1 where c2 > 1")?; assert!(contains_operator(&best_plan, |op| matches!( op, Operator::Filter(_) ))); - assert_single_scan_columns(&best_plan, "t1", &["c1", "c2"]); + assert_single_scan_columns(&best_plan, "t1", &arena, &["c1", "c2"]); Ok(()) } #[test] fn test_column_pruning_aggregate_single_side() -> Result<(), DatabaseError> { - let best_plan = optimize_column_pruning("select sum(c1) from t1")?; + let table_state = build_t1_table()?; + let arena = PlanArena::new(&table_state.table_arena); + let mut arena = arena; + let best_plan = + optimize_column_pruning(&table_state, &mut arena, "select sum(c1) from t1")?; assert!(contains_operator(&best_plan, |op| matches!( op, Operator::Aggregate(_) ))); - assert_single_scan_columns(&best_plan, "t1", &["c1"]); + assert_single_scan_columns(&best_plan, "t1", &arena, &["c1"]); Ok(()) } #[test] fn test_column_pruning_sort_single_side() -> Result<(), DatabaseError> { - let best_plan = optimize_column_pruning("select c1 from t1 order by c2")?; + let table_state = build_t1_table()?; + let arena = PlanArena::new(&table_state.table_arena); + let mut arena = arena; + let best_plan = + optimize_column_pruning(&table_state, &mut arena, "select c1 from t1 order by c2")?; assert!(contains_operator(&best_plan, |op| matches!( op, Operator::Sort(_) ))); - assert_single_scan_columns(&best_plan, "t1", &["c1", "c2"]); + assert_single_scan_columns(&best_plan, "t1", &arena, &["c1", "c2"]); Ok(()) } #[test] fn test_column_pruning_limit_single_side() -> Result<(), DatabaseError> { - let best_plan = optimize_column_pruning("select c1 from t1 limit 1")?; + let table_state = build_t1_table()?; + let arena = PlanArena::new(&table_state.table_arena); + let mut arena = arena; + let best_plan = + optimize_column_pruning(&table_state, &mut arena, "select c1 from t1 limit 1")?; assert!(contains_operator(&best_plan, |op| matches!( op, Operator::Limit(_) ))); - assert_single_scan_columns(&best_plan, "t1", &["c1"]); + assert_single_scan_columns(&best_plan, "t1", &arena, &["c1"]); Ok(()) } @@ -834,7 +1019,9 @@ mod tests { #[test] fn test_column_pruning() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select c1, c3 from t1 left join t2 on c1 = c3")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state + .plan_with_arena("select c1, c3 from t1 left join t2 on c1 = c3", &mut arena)?; let pipeline = HepOptimizerPipeline::builder() .before_batch( @@ -843,9 +1030,7 @@ mod tests { vec![NormalizationRuleImpl::ColumnPruning], ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, &mut arena)?; assert!(matches!(best_plan.childrens.as_ref(), Childrens::Only(_))); match best_plan.operator { diff --git a/src/optimizer/rule/normalization/combine_operators.rs b/src/optimizer/rule/normalization/combine_operators.rs index a548145e..edcc853d 100644 --- a/src/optimizer/rule/normalization/combine_operators.rs +++ b/src/optimizer/rule/normalization/combine_operators.rs @@ -16,7 +16,7 @@ use crate::errors::DatabaseError; use crate::expression::{AliasType, BinaryOperator, ScalarExpression}; use crate::optimizer::core::rule::NormalizationRule; use crate::optimizer::plan_utils::{only_child_mut, replace_with_only_child}; -use crate::optimizer::rule::normalization::{is_subset_exprs, strip_alias}; +use crate::optimizer::rule::normalization::strip_alias; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::Operator; @@ -38,6 +38,13 @@ fn passthrough_source_position(expr: &ScalarExpression) -> Option { } } +fn collapse_match_expr(expr: &ScalarExpression) -> &ScalarExpression { + match expr { + ScalarExpression::Alias { expr, .. } => expr, + _ => expr, + } +} + fn rewrite_column_position(expr: &mut ScalarExpression, new_position: usize) { match expr { ScalarExpression::ColumnRef { position, .. } => { @@ -56,13 +63,15 @@ fn rewrite_column_position(expr: &mut ScalarExpression, new_position: usize) { fn remap_passthrough_project_exprs( parent_exprs: &mut [ScalarExpression], child_exprs: &[ScalarExpression], + arena: &crate::planner::PlanArena, ) -> bool { let mut remapped_positions = Vec::with_capacity(parent_exprs.len()); for parent_expr in parent_exprs.iter() { + let parent_match_expr = collapse_match_expr(parent_expr); let Some(position) = child_exprs .iter() - .find(|child_expr| parent_expr.eq_ignore_colref_pos(child_expr)) + .find(|child_expr| parent_match_expr.eq_ignore_colref_pos(child_expr, arena)) .and_then(passthrough_source_position) else { return false; @@ -80,19 +89,24 @@ fn remap_passthrough_project_exprs( fn groupby_exprs_match( parent_exprs: &[ScalarExpression], child_exprs: &[ScalarExpression], + arena: &crate::planner::PlanArena, ) -> bool { parent_exprs.len() == child_exprs.len() && parent_exprs .iter() .zip(child_exprs.iter()) - .all(|(parent_expr, child_expr)| parent_expr.eq_ignore_colref_pos(child_expr)) + .all(|(parent_expr, child_expr)| parent_expr.eq_ignore_colref_pos(child_expr, arena)) } /// Combine two adjacent project operators into one. pub struct CollapseProject; impl NormalizationRule for CollapseProject { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { let Operator::Project(parent_op) = &mut plan.operator else { return Ok(false); }; @@ -105,10 +119,10 @@ impl NormalizationRule for CollapseProject { match &child.operator { Operator::Project(child_op) if is_passthrough_project(child_op) - && is_subset_exprs(&parent_op.exprs, &child_op.exprs) && remap_passthrough_project_exprs( &mut parent_op.exprs, &child_op.exprs, + arena, ) => { removed |= replace_with_only_child(child.as_mut()); @@ -125,7 +139,11 @@ impl NormalizationRule for CollapseProject { pub struct CombineFilter; impl NormalizationRule for CombineFilter { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + _: &mut crate::planner::PlanArena, + ) -> Result { let parent_filter = match mem::replace(&mut plan.operator, Operator::Dummy) { Operator::Filter(op) => op, operator => { @@ -185,7 +203,11 @@ impl NormalizationRule for CombineFilter { pub struct CollapseGroupByAgg; impl NormalizationRule for CollapseGroupByAgg { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { let can_collapse = { let LogicalPlan { operator, @@ -205,7 +227,7 @@ impl NormalizationRule for CollapseGroupByAgg { let Operator::Aggregate(child_op) = &child.operator else { return Ok(false); }; - groupby_exprs_match(&op.groupby_exprs, &child_op.groupby_exprs) + groupby_exprs_match(&op.groupby_exprs, &child_op.groupby_exprs, arena) }; if can_collapse { @@ -219,7 +241,7 @@ impl NormalizationRule for CollapseGroupByAgg { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use crate::binder::test::build_t1_table; - use crate::catalog::{ColumnCatalog, ColumnRef}; + use crate::catalog::ColumnCatalog; use crate::errors::DatabaseError; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::rule::NormalizationRule; @@ -232,20 +254,19 @@ mod tests { use crate::planner::operator::aggregate::AggregateOperator; use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::Operator; - use crate::planner::{Childrens, LogicalPlan}; - use crate::storage::rocksdb::RocksTransaction; - - fn column_expr(name: &str, position: usize) -> ScalarExpression { - ScalarExpression::column_expr( - ColumnRef::from(ColumnCatalog::new_dummy(name.to_string())), - position, - ) + use crate::planner::{Childrens, LogicalPlan, PlanArena}; + + fn column_expr(arena: &mut PlanArena, name: &str, position: usize) -> ScalarExpression { + let column = arena.alloc_column(ColumnCatalog::new_dummy(name.to_string())); + ScalarExpression::column_expr(column, position) } #[test] fn test_collapse_project() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select c1 from (select c1, c2 from t1) t")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = + table_state.plan_with_arena("select c1 from (select c1, c2 from t1) t", &mut arena)?; let pipeline = HepOptimizerPipeline::builder() .before_batch( @@ -254,9 +275,7 @@ mod tests { vec![NormalizationRuleImpl::CollapseProject], ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, &mut arena)?; if let Operator::Project(op) = &best_plan.operator { assert_eq!(op.exprs.len(), 1); @@ -279,7 +298,9 @@ mod tests { #[test] fn test_collapse_project_with_alias() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select t.x from (select c1 as x from t1) t")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state + .plan_with_arena("select t.x from (select c1 as x from t1) t", &mut arena)?; let original = plan.clone(); let original_child = original.childrens.pop_only(); assert!(matches!(original_child.operator, Operator::Project(_))); @@ -293,9 +314,7 @@ mod tests { vec![NormalizationRuleImpl::CollapseProject], ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, &mut arena)?; if let Operator::Project(op) = &best_plan.operator { assert_eq!(op.exprs.len(), 1); } else { @@ -315,20 +334,25 @@ mod tests { #[test] fn test_collapse_project_remaps_reordered_passthrough_positions() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); let child = LogicalPlan::new( Operator::Project(ProjectOperator { - exprs: vec![column_expr("c2", 1), column_expr("c1", 0)], + exprs: vec![ + column_expr(&mut arena, "c2", 1), + column_expr(&mut arena, "c1", 0), + ], }), Childrens::Only(Box::new(LogicalPlan::new(Operator::Dummy, Childrens::None))), ); let mut plan = LogicalPlan::new( Operator::Project(ProjectOperator { - exprs: vec![column_expr("c2", 0)], + exprs: vec![column_expr(&mut arena, "c2", 0)], }), Childrens::Only(Box::new(child)), ); - assert!(CollapseProject.apply(&mut plan)?); + assert!(CollapseProject.apply(&mut plan, &mut arena)?); let Operator::Project(op) = &plan.operator else { unreachable!("expected project"); @@ -347,8 +371,11 @@ mod tests { #[test] fn test_combine_filter() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = - table_state.plan("select * from (select * from t1 where c1 > 1) t where 1 = 1")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select * from (select * from t1 where c1 > 1) t where 1 = 1", + &mut arena, + )?; let pipeline = HepOptimizerPipeline::builder() .before_batch( @@ -357,9 +384,7 @@ mod tests { vec![NormalizationRuleImpl::CombineFilter], ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, &mut arena)?; let filter_op = best_plan.childrens.pop_only(); if let Operator::Filter(op) = &filter_op.operator { @@ -378,7 +403,9 @@ mod tests { #[test] fn test_collapse_group_by_agg() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select distinct c1, c2 from t1 group by c1, c2")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state + .plan_with_arena("select distinct c1, c2 from t1 group by c1, c2", &mut arena)?; let pipeline = HepOptimizerPipeline::builder() .before_batch( @@ -388,9 +415,7 @@ mod tests { ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, &mut arena)?; let agg_op = best_plan.childrens.pop_only(); if let Operator::Aggregate(_) = &agg_op.operator { @@ -406,15 +431,18 @@ mod tests { #[test] fn test_collapse_group_by_agg_ignores_columnref_position() -> Result<(), DatabaseError> { + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); let child = AggregateOperator::build( LogicalPlan::new(Operator::Dummy, Childrens::None), vec![], - vec![column_expr("c2", 1)], + vec![column_expr(&mut arena, "c2", 1)], false, ); - let mut plan = AggregateOperator::build(child, vec![], vec![column_expr("c2", 0)], true); + let expr = column_expr(&mut arena, "c2", 0); + let mut plan = AggregateOperator::build(child, vec![], vec![expr], true); - assert!(CollapseGroupByAgg.apply(&mut plan)?); + assert!(CollapseGroupByAgg.apply(&mut plan, &mut arena)?); let Operator::Aggregate(op) = &plan.operator else { unreachable!("expected aggregate"); }; diff --git a/src/optimizer/rule/normalization/compilation_in_advance.rs b/src/optimizer/rule/normalization/compilation_in_advance.rs index 7b0dee7f..779d76f6 100644 --- a/src/optimizer/rule/normalization/compilation_in_advance.rs +++ b/src/optimizer/rule/normalization/compilation_in_advance.rs @@ -18,24 +18,28 @@ use crate::expression::BindEvaluator; use crate::optimizer::core::rule::NormalizationRule; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::Operator; -use crate::planner::{Childrens, LogicalPlan}; +use crate::planner::{Childrens, LogicalPlan, PlanArena}; #[derive(Clone)] pub struct EvaluatorBind; -pub(crate) fn evaluator_bind_current(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { +pub(crate) fn evaluator_bind_current( + plan: &mut LogicalPlan, + arena: &PlanArena, +) -> Result<(), DatabaseError> { let operator = &mut plan.operator; + let mut evaluator = BindEvaluator { arena }; match operator { Operator::Join(op) => { match &mut op.on { JoinCondition::On { on, filter } => { for (left_expr, right_expr) in on { - BindEvaluator.visit(left_expr)?; - BindEvaluator.visit(right_expr)?; + evaluator.visit(left_expr)?; + evaluator.visit(right_expr)?; } if let Some(expr) = filter { - BindEvaluator.visit(expr)?; + evaluator.visit(expr)?; } } JoinCondition::None => {} @@ -45,41 +49,41 @@ pub(crate) fn evaluator_bind_current(plan: &mut LogicalPlan) -> Result<(), Datab } Operator::Aggregate(op) => { for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { - BindEvaluator.visit(expr)?; + evaluator.visit(expr)?; } } Operator::Filter(op) => { - BindEvaluator.visit(&mut op.predicate)?; + evaluator.visit(&mut op.predicate)?; } Operator::Project(op) => { for expr in op.exprs.iter_mut() { - BindEvaluator.visit(expr)?; + evaluator.visit(expr)?; } } Operator::MarkApply(op) => { for predicate in op.predicates_mut().iter_mut() { - BindEvaluator.visit(predicate)?; + evaluator.visit(predicate)?; } } Operator::ScalarApply(_) => {} Operator::Sort(op) => { for sort_field in op.sort_fields.iter_mut() { - BindEvaluator.visit(&mut sort_field.expr)?; + evaluator.visit(&mut sort_field.expr)?; } } Operator::TopK(op) => { for sort_field in op.sort_fields.iter_mut() { - BindEvaluator.visit(&mut sort_field.expr)?; + evaluator.visit(&mut sort_field.expr)?; } } Operator::FunctionScan(op) => { for expr in op.table_function.args.iter_mut() { - BindEvaluator.visit(expr)?; + evaluator.visit(expr)?; } } Operator::Update(op) => { for (_, expr) in op.value_exprs.iter_mut() { - BindEvaluator.visit(expr)?; + evaluator.visit(expr)?; } } Operator::Dummy @@ -104,21 +108,21 @@ pub(crate) fn evaluator_bind_current(plan: &mut LogicalPlan) -> Result<(), Datab | Operator::DropView(_) | Operator::DropIndex(_) | Operator::Truncate(_) - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) | Operator::Union(_) | Operator::SetMembership(_) => (), + #[cfg(feature = "copy")] + Operator::CopyFromFile(_) | Operator::CopyToFile(_) => (), } Ok(()) } impl EvaluatorBind { - fn _apply(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { + fn _apply(plan: &mut LogicalPlan, arena: &PlanArena) -> Result<(), DatabaseError> { match plan.childrens.as_mut() { - Childrens::Only(child) => Self::_apply(child)?, + Childrens::Only(child) => Self::_apply(child, arena)?, Childrens::Twins { left, right } => { - Self::_apply(left)?; + Self::_apply(left, arena)?; if matches!( plan.operator, Operator::ScalarApply(_) @@ -127,19 +131,23 @@ impl EvaluatorBind { | Operator::Union(_) | Operator::SetMembership(_) ) { - Self::_apply(right)?; + Self::_apply(right, arena)?; } } Childrens::None => {} } - evaluator_bind_current(plan) + evaluator_bind_current(plan, arena) } } impl NormalizationRule for EvaluatorBind { - fn apply(&self, plan: &mut LogicalPlan) -> Result { - Self::_apply(plan)?; + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { + Self::_apply(plan, arena)?; Ok(true) } } diff --git a/src/optimizer/rule/normalization/min_max_top_k.rs b/src/optimizer/rule/normalization/min_max_top_k.rs index 048a39cd..b1ce260b 100644 --- a/src/optimizer/rule/normalization/min_max_top_k.rs +++ b/src/optimizer/rule/normalization/min_max_top_k.rs @@ -25,7 +25,11 @@ use crate::planner::LogicalPlan; pub struct MinMaxToTopK; impl NormalizationRule for MinMaxToTopK { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + _: &mut crate::planner::PlanArena, + ) -> Result { let Operator::Aggregate(op) = &plan.operator else { return Ok(false); }; @@ -83,7 +87,7 @@ mod tests { use crate::errors::DatabaseError; use crate::optimizer::core::rule::NormalizationRule; use crate::planner::operator::Operator; - use crate::planner::Childrens; + use crate::planner::{Childrens, PlanArena}; fn find_aggregate(plan: &crate::planner::LogicalPlan) -> &crate::planner::LogicalPlan { if matches!(plan.operator, Operator::Aggregate(_)) { @@ -110,10 +114,11 @@ mod tests { #[test] fn test_min_to_topk() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let mut plan = table_state.plan("select min(c1) from t1")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let mut plan = table_state.plan_with_arena("select min(c1) from t1", &mut arena)?; let agg_plan = find_aggregate_mut(&mut plan); - assert!(MinMaxToTopK.apply(agg_plan)?); + assert!(MinMaxToTopK.apply(agg_plan, &mut arena)?); let agg_plan = find_aggregate(&plan); let Operator::Aggregate(op) = &agg_plan.operator else { @@ -147,10 +152,11 @@ mod tests { #[test] fn test_max_to_topk() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let mut plan = table_state.plan("select max(c2) from t1")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let mut plan = table_state.plan_with_arena("select max(c2) from t1", &mut arena)?; let agg_plan = find_aggregate_mut(&mut plan); - assert!(MinMaxToTopK.apply(agg_plan)?); + assert!(MinMaxToTopK.apply(agg_plan, &mut arena)?); let agg_plan = find_aggregate(&plan); let child = match agg_plan.childrens.as_ref() { diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index 4b26db5c..011412e2 100644 --- a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -130,12 +130,12 @@ impl NormalizationRuleRootTag { | Operator::DropView(_) | Operator::DropIndex(_) | Operator::Truncate(_) - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) | Operator::FunctionScan(_) | Operator::Update(_) | Operator::Union(_) | Operator::SetMembership(_) => None, + #[cfg(feature = "copy")] + Operator::CopyFromFile(_) | Operator::CopyToFile(_) => None, } } } @@ -182,31 +182,42 @@ impl NormalizationRuleImpl { } impl NormalizationRule for NormalizationRuleImpl { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { match self { - NormalizationRuleImpl::ColumnPruning => ColumnPruning.apply(plan), - NormalizationRuleImpl::CollapseProject => CollapseProject.apply(plan), - NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(plan), - NormalizationRuleImpl::CombineFilter => CombineFilter.apply(plan), - NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.apply(plan), - NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.apply(plan), - NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.apply(plan), - NormalizationRuleImpl::PushPredicateThroughJoin => PushPredicateThroughJoin.apply(plan), + NormalizationRuleImpl::ColumnPruning => ColumnPruning.apply(plan, arena), + NormalizationRuleImpl::CollapseProject => CollapseProject.apply(plan, arena), + NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(plan, arena), + NormalizationRuleImpl::CombineFilter => CombineFilter.apply(plan, arena), + NormalizationRuleImpl::LimitProjectTranspose => { + LimitProjectTranspose.apply(plan, arena) + } + NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.apply(plan, arena), + NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.apply(plan, arena), + NormalizationRuleImpl::PushPredicateThroughJoin => { + PushPredicateThroughJoin.apply(plan, arena) + } NormalizationRuleImpl::PushJoinPredicateIntoScan => { - PushJoinPredicateIntoScan.apply(plan) + PushJoinPredicateIntoScan.apply(plan, arena) + } + NormalizationRuleImpl::SimplifyFilter => SimplifyFilter.apply(plan, arena), + NormalizationRuleImpl::PushPredicateIntoScan => { + PushPredicateIntoScan.apply(plan, arena) + } + NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.apply(plan, arena), + NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.apply(plan, arena), + NormalizationRuleImpl::MinMaxToTopK => MinMaxToTopK.apply(plan, arena), + NormalizationRuleImpl::TopK => TopK.apply(plan, arena), + NormalizationRuleImpl::ParameterizeMarkApply => { + ParameterizeMarkApply.apply(plan, arena) } - NormalizationRuleImpl::SimplifyFilter => SimplifyFilter.apply(plan), - NormalizationRuleImpl::PushPredicateIntoScan => PushPredicateIntoScan.apply(plan), - NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.apply(plan), - NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.apply(plan), - NormalizationRuleImpl::MinMaxToTopK => MinMaxToTopK.apply(plan), - NormalizationRuleImpl::TopK => TopK.apply(plan), - NormalizationRuleImpl::ParameterizeMarkApply => ParameterizeMarkApply.apply(plan), } } } -/// Return true when left is subset of right pub(crate) fn strip_alias(expr: &ScalarExpression) -> &ScalarExpression { match expr { ScalarExpression::Alias { @@ -221,29 +232,6 @@ pub(crate) fn strip_alias(expr: &ScalarExpression) -> &ScalarExpression { } } -fn strip_all_alias(expr: &ScalarExpression) -> &ScalarExpression { - match expr { - ScalarExpression::Alias { expr, .. } => strip_all_alias(expr), - _ => expr, - } -} - -pub fn is_subset_exprs(left: &[ScalarExpression], right: &[ScalarExpression]) -> bool { - left.iter().all(|lhs| { - let lhs_stripped = strip_alias(lhs); - right.iter().any(|rhs| { - let rhs_stripped = strip_alias(rhs); - if lhs_stripped.eq_ignore_colref_pos(rhs_stripped) { - return true; - } - if matches!(lhs_stripped, ScalarExpression::ColumnRef { .. }) { - return lhs_stripped.eq_ignore_colref_pos(strip_all_alias(rhs)); - } - false - }) - }) -} - pub(crate) fn remap_position(position: &mut usize, removed_positions: &[usize]) { match removed_positions.binary_search(position) { Ok(_) => { diff --git a/src/optimizer/rule/normalization/parameterized_index.rs b/src/optimizer/rule/normalization/parameterized_index.rs index 9bb8b160..be0042f5 100644 --- a/src/optimizer/rule/normalization/parameterized_index.rs +++ b/src/optimizer/rule/normalization/parameterized_index.rs @@ -26,17 +26,22 @@ use crate::types::tuple::Schema; pub(crate) struct ParameterizeMarkApply; impl NormalizationRule for ParameterizeMarkApply { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { let (op, new_probe) = match (&mut plan.operator, plan.childrens.as_mut()) { (Operator::MarkApply(op), Childrens::Twins { left, right }) => { - let new_probe = find_parameterized_probe( - op.kind.clone(), + let probe = find_parameterized_probe( + op.kind, op.predicates(), - left.output_schema().as_ref(), - right.output_schema().as_ref(), - ) - .and_then(|(right_column, left_expr)| { - parameterize_right_subtree(right, &right_column).then_some(left_expr) + left.output_schema(arena), + right.output_schema(arena), + arena, + ); + let new_probe = probe.and_then(|(right_column, left_expr)| { + parameterize_right_subtree(right, &right_column, arena).then_some(left_expr) }); (op, new_probe) } @@ -54,14 +59,15 @@ fn find_parameterized_probe( predicates: &[ScalarExpression], left_schema: &Schema, right_schema: &Schema, + arena: &crate::planner::PlanArena, ) -> Option<(ColumnRef, ScalarExpression)> { match kind { MarkApplyKind::Exists => predicates.iter().find_map(|predicate| { - extract_parameterized_probe(predicate, left_schema, right_schema) + extract_parameterized_probe(predicate, left_schema, right_schema, arena) }), MarkApplyKind::Quantified(MarkApplyQuantifier::Any) => { predicates.first().and_then(|predicate| { - extract_parameterized_probe(predicate, left_schema, right_schema) + extract_parameterized_probe(predicate, left_schema, right_schema, arena) }) } MarkApplyKind::Quantified(MarkApplyQuantifier::All) => None, @@ -72,6 +78,7 @@ fn extract_parameterized_probe( predicate: &ScalarExpression, left_schema: &Schema, right_schema: &Schema, + arena: &crate::planner::PlanArena, ) -> Option<(ColumnRef, ScalarExpression)> { match predicate.unpack_alias_ref() { ScalarExpression::Binary { @@ -79,10 +86,22 @@ fn extract_parameterized_probe( left_expr, right_expr, .. - } => extract_parameterized_probe_side(left_expr, right_expr, left_schema, right_schema) - .or_else(|| { - extract_parameterized_probe_side(right_expr, left_expr, left_schema, right_schema) - }), + } => extract_parameterized_probe_side( + left_expr, + right_expr, + left_schema, + right_schema, + arena, + ) + .or_else(|| { + extract_parameterized_probe_side( + right_expr, + left_expr, + left_schema, + right_schema, + arena, + ) + }), _ => None, } } @@ -92,19 +111,20 @@ fn extract_parameterized_probe_side( left_expr: &ScalarExpression, left_schema: &Schema, right_schema: &Schema, + arena: &crate::planner::PlanArena, ) -> Option<(ColumnRef, ScalarExpression)> { let (right_column, _) = right_expr.unpack_alias_ref().unpack_bound_col(false)?; - if !schema_contains_column(right_schema, &right_column) { + if !schema_contains_column(right_schema, &right_column, arena) { return None; } - if !left_expr.all_referenced_columns(true, |candidate| { - schema_contains_column(left_schema, candidate) + if !left_expr.all_referenced_columns(arena, |arena, candidate| { + schema_contains_column(left_schema, candidate, arena) }) { return None; } - if left_expr.any_referenced_column(true, |candidate| { - schema_contains_column(right_schema, candidate) + if left_expr.any_referenced_column(arena, |arena, candidate| { + schema_contains_column(right_schema, candidate, arena) }) { return None; } @@ -112,13 +132,18 @@ fn extract_parameterized_probe_side( Some((right_column, left_expr.clone())) } -fn parameterize_right_subtree(plan: &mut LogicalPlan, right_column: &ColumnRef) -> bool { +fn parameterize_right_subtree( + plan: &mut LogicalPlan, + right_column: &ColumnRef, + arena: &crate::planner::PlanArena, +) -> bool { if matches!(plan.operator, Operator::TableScan(_)) { let index_info = { let Operator::TableScan(scan_op) = &mut plan.operator else { unreachable!(); }; - let Some(target_index) = pick_parameterized_index_position(scan_op, right_column) + let Some(target_index) = + pick_parameterized_index_position(scan_op, right_column, arena) else { return false; }; @@ -147,7 +172,7 @@ fn parameterize_right_subtree(plan: &mut LogicalPlan, right_column: &ColumnRef) } match plan.childrens.as_mut() { - Childrens::Only(child) => parameterize_right_subtree(child, right_column), + Childrens::Only(child) => parameterize_right_subtree(child, right_column, arena), _ => false, } } @@ -155,7 +180,9 @@ fn parameterize_right_subtree(plan: &mut LogicalPlan, right_column: &ColumnRef) fn pick_parameterized_index_position( scan_op: &TableScanOperator, right_column: &ColumnRef, + arena: &crate::planner::PlanArena, ) -> Option { + let right_column = arena.column(*right_column); let column_id = right_column.id()?; let table_name = right_column.table_name()?; @@ -168,10 +195,11 @@ fn pick_parameterized_index_position( .iter() .enumerate() .filter(|(_, index_info)| { - index_info.meta.table_name == *table_name - && index_info.meta.column_ids.first().copied() == Some(column_id) + let index_meta = arena.index(index_info.meta); + index_meta.table_name == *table_name + && index_meta.column_ids.first().copied() == Some(column_id) }) - .min_by_key(|(_, index_info)| index_priority(index_info.meta.ty)) + .min_by_key(|(_, index_info)| index_priority(arena.index(index_info.meta).ty)) .map(|(position, _)| position) } @@ -184,6 +212,12 @@ fn index_priority(index_type: IndexType) -> usize { } } -fn schema_contains_column(schema: &Schema, column: &ColumnRef) -> bool { - schema.iter().any(|candidate| candidate.same_column(column)) +fn schema_contains_column( + schema: &Schema, + column: &ColumnRef, + arena: &crate::planner::PlanArena, +) -> bool { + schema + .iter() + .any(|candidate| arena.same_column(*candidate, *column)) } diff --git a/src/optimizer/rule/normalization/pushdown_limit.rs b/src/optimizer/rule/normalization/pushdown_limit.rs index bbdb0c8d..e61779d4 100644 --- a/src/optimizer/rule/normalization/pushdown_limit.rs +++ b/src/optimizer/rule/normalization/pushdown_limit.rs @@ -22,7 +22,11 @@ use crate::planner::LogicalPlan; pub struct LimitProjectTranspose; impl NormalizationRule for LimitProjectTranspose { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + _: &mut crate::planner::PlanArena, + ) -> Result { let operator = std::mem::replace(&mut plan.operator, Operator::Dummy); let limit_op = match operator { @@ -63,7 +67,11 @@ impl NormalizationRule for LimitProjectTranspose { pub struct PushLimitThroughJoin; impl NormalizationRule for PushLimitThroughJoin { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + _: &mut crate::planner::PlanArena, + ) -> Result { let limit_op = match &plan.operator { Operator::Limit(op) => op.clone(), _ => return Ok(false), @@ -93,7 +101,11 @@ impl NormalizationRule for PushLimitThroughJoin { pub struct PushLimitIntoScan; impl NormalizationRule for PushLimitIntoScan { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + _: &mut crate::planner::PlanArena, + ) -> Result { let (offset, limit) = match &plan.operator { Operator::Limit(limit_op) => (limit_op.offset, limit_op.limit), _ => return Ok(false), @@ -119,12 +131,13 @@ mod tests { use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; use crate::optimizer::rule::normalization::NormalizationRuleImpl; use crate::planner::operator::Operator; - use crate::storage::rocksdb::RocksTransaction; + use crate::planner::PlanArena; #[test] fn test_limit_project_transpose() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select c1, c2 from t1 limit 1")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena("select c1, c2 from t1 limit 1", &mut arena)?; let pipeline = HepOptimizerPipeline::builder() .before_batch( @@ -133,9 +146,7 @@ mod tests { vec![NormalizationRuleImpl::LimitProjectTranspose], ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, &mut arena)?; if let Operator::Project(_) = &best_plan.operator { } else { @@ -154,7 +165,11 @@ mod tests { #[test] fn test_push_limit_through_join() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3 limit 1")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select * from t1 left join t2 on c1 = c3 limit 1", + &mut arena, + )?; let pipeline = HepOptimizerPipeline::builder() .before_batch( @@ -166,9 +181,7 @@ mod tests { ], ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, &mut arena)?; let join_op = best_plan.childrens.pop_only().childrens.pop_only(); if let Operator::Join(_) = &join_op.operator { @@ -189,7 +202,8 @@ mod tests { #[test] fn test_push_limit_into_table_scan() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 limit 1 offset 1")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena("select * from t1 limit 1 offset 1", &mut arena)?; let pipeline = HepOptimizerPipeline::builder() .before_batch( @@ -201,9 +215,7 @@ mod tests { ], ) .build(); - let best_plan = pipeline - .instantiate(plan) - .find_best::(None)?; + let best_plan = pipeline.instantiate(plan).find_best(None, &mut arena)?; let scan_op = best_plan.childrens.pop_only(); if let Operator::TableScan(op) = &scan_op.operator { diff --git a/src/optimizer/rule/normalization/pushdown_predicates.rs b/src/optimizer/rule/normalization/pushdown_predicates.rs index b851c273..af0e49bb 100644 --- a/src/optimizer/rule/normalization/pushdown_predicates.rs +++ b/src/optimizer/rule/normalization/pushdown_predicates.rs @@ -12,19 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::ColumnRef; use crate::errors::DatabaseError; use crate::expression::range_detacher::{Range, RangeDetacher}; use crate::expression::visitor_mut::{PositionShift, VisitorMut}; use crate::expression::{BinaryOperator, ScalarExpression}; use crate::optimizer::core::rule::NormalizationRule; -use crate::optimizer::plan_utils::{ - left_child, replace_with_only_child, right_child, wrap_child_with, -}; +use crate::optimizer::plan_utils::{replace_with_only_child, wrap_child_with}; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::join::{JoinCondition, JoinType}; use crate::planner::operator::{Operator, SortOption}; -use crate::planner::{Childrens, LogicalPlan, SchemaOutput}; +use crate::planner::{Childrens, LogicalPlan}; use crate::types::index::{IndexInfo, IndexLookup, IndexMetaRef, IndexType}; use crate::types::value::DataValue; use crate::types::LogicalType; @@ -32,6 +29,8 @@ use itertools::Itertools; use std::ops::Bound; use std::{mem, slice}; +const EMPTY_SCHEMA: [crate::catalog::ColumnRef; 0] = []; + fn split_conjunctive_predicates(expr: &ScalarExpression) -> Vec { match expr { ScalarExpression::Binary { @@ -66,13 +65,6 @@ fn reduce_filters(filters: Vec, having: bool) -> Option Vec { - match plan.output_schema_direct() { - SchemaOutput::Schema(schema) => schema, - SchemaOutput::SchemaRef(schema_ref) => schema_ref.iter().cloned().collect(), - } -} - fn localize_right_filters( filters: &mut [ScalarExpression], left_len: usize, @@ -101,7 +93,11 @@ fn localize_right_filters( pub struct PushPredicateThroughJoin; impl NormalizationRule for PushPredicateThroughJoin { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { let mut applied = false; let parent_replacement = { @@ -120,43 +116,47 @@ impl NormalizationRule for PushPredicateThroughJoin { }; let join_plan = join_plan.as_mut(); - let join_op = match &join_plan.operator { - Operator::Join(op) => op, + let join_type = match &join_plan.operator { + Operator::Join(op) => op.join_type, _ => return Ok(false), }; if !matches!( - join_op.join_type, + join_type, JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter ) { return Ok(false); } - let left_columns = left_child(join_plan) - .map(plan_output_columns) - .unwrap_or_default(); - let right_columns = right_child(join_plan) - .map(plan_output_columns) - .unwrap_or_default(); - let filter_exprs = split_conjunctive_predicates(&filter_op.predicate); + let left_columns: &[crate::catalog::ColumnRef] = match join_plan.childrens.as_mut() { + Childrens::Only(left) => left.output_schema(arena), + Childrens::Twins { left, .. } => left.output_schema(arena), + Childrens::None => &EMPTY_SCHEMA, + }; let (left_filters, rest): (Vec<_>, Vec<_>) = filter_exprs.into_iter().partition(|f| { - f.all_referenced_columns(true, |column| left_columns.contains(column)) + f.all_referenced_columns(arena, |_, column| left_columns.contains(column)) }); + let left_len = left_columns.len(); + + let right_columns: &[crate::catalog::ColumnRef] = match join_plan.childrens.as_mut() { + Childrens::Twins { right, .. } => right.output_schema(arena), + _ => &EMPTY_SCHEMA, + }; let (right_filters, common_filters): (Vec<_>, Vec<_>) = rest.into_iter().partition(|f| { - f.all_referenced_columns(true, |column| right_columns.contains(column)) + f.all_referenced_columns(arena, |_, column| right_columns.contains(column)) }); let mut new_ops = (None, None, None); - let replace_filters = match join_op.join_type { + let replace_filters = match join_type { JoinType::Inner => { if let Some(left_filter_op) = reduce_filters(left_filters, filter_op.having) { new_ops.0 = Some(Operator::Filter(left_filter_op)); } let mut right_filters = right_filters; - localize_right_filters(&mut right_filters, left_columns.len())?; + localize_right_filters(&mut right_filters, left_len)?; if let Some(right_filter_op) = reduce_filters(right_filters, filter_op.having) { new_ops.1 = Some(Operator::Filter(right_filter_op)); } @@ -175,7 +175,7 @@ impl NormalizationRule for PushPredicateThroughJoin { } JoinType::RightOuter => { let mut right_filters = right_filters; - localize_right_filters(&mut right_filters, left_columns.len())?; + localize_right_filters(&mut right_filters, left_len)?; if let Some(right_filter_op) = reduce_filters(right_filters, filter_op.having) { new_ops.1 = Some(Operator::Filter(right_filter_op)); } @@ -218,7 +218,11 @@ impl NormalizationRule for PushPredicateThroughJoin { pub struct PushPredicateIntoScan; impl NormalizationRule for PushPredicateIntoScan { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { let LogicalPlan { operator, childrens, @@ -256,16 +260,19 @@ impl NormalizationRule for PushPredicateIntoScan { else { return Err(DatabaseError::InvalidIndex); }; - *lookup = match meta.ty { + let index_meta = arena.index(*meta); + *lookup = match index_meta.ty { IndexType::PrimaryKey { is_multiple: false } | IndexType::Unique - | IndexType::Normal => { - RangeDetacher::new(meta.table_name.as_ref(), &meta.column_ids[0]) - .detach(&filter_op.predicate)? - .map(IndexLookup::Static) - } + | IndexType::Normal => RangeDetacher::new( + index_meta.table_name.as_ref(), + &index_meta.column_ids[0], + arena, + ) + .detach(&filter_op.predicate)? + .map(IndexLookup::Static), IndexType::PrimaryKey { is_multiple: true } | IndexType::Composite => { - Self::composite_range(filter_op, meta, ignore_prefix_len)? + Self::composite_range(filter_op, *meta, ignore_prefix_len, arena)? .map(IndexLookup::Static) } }; @@ -280,22 +287,26 @@ impl NormalizationRule for PushPredicateIntoScan { // try index covered let mut mapping_slots = vec![usize::MAX; scan_op.columns.len()]; let mut needs_mapping = false; - let index_column_types = match &meta.value_ty { + let index_meta = arena.index(*meta); + let index_column_types = match &index_meta.value_ty { LogicalType::Tuple(tys) => tys, ty => slice::from_ref(ty), }; - let mut deserializers = Vec::with_capacity(meta.column_ids.len()); - - for (idx, column_id) in meta.column_ids.iter().enumerate() { - if let Some((scan_idx, column)) = scan_op - .columns - .iter() - .enumerate() - .find(|(_, column)| column.id().map(|id| id == *column_id).unwrap_or(false)) + let mut deserializers = Vec::with_capacity(index_meta.column_ids.len()); + + for (idx, column_id) in index_meta.column_ids.iter().enumerate() { + if let Some((scan_idx, column)) = + scan_op.columns.iter().enumerate().find(|(_, column)| { + arena + .column(**column) + .id() + .map(|id| id == *column_id) + .unwrap_or(false) + }) { mapping_slots[scan_idx] = idx; needs_mapping |= scan_idx != idx; - deserializers.push(column.datatype().serializable()); + deserializers.push(arena.column(*column).datatype().serializable()); } else { deserializers.push(index_column_types[idx].skip_serializable()); } @@ -316,16 +327,18 @@ impl NormalizationRule for PushPredicateIntoScan { impl PushPredicateIntoScan { fn composite_range( op: &FilterOperator, - meta: &mut IndexMetaRef, + meta: IndexMetaRef, ignore_prefix_len: &mut usize, + arena: &crate::planner::PlanArena, ) -> Result, DatabaseError> { + let meta = arena.index(meta); let mut res = None; let mut eq_ranges = Vec::with_capacity(meta.column_ids.len()); let mut apply_column_count = 0; for column_id in meta.column_ids.iter() { - if let Some(range) = - RangeDetacher::new(meta.table_name.as_ref(), column_id).detach(&op.predicate)? + if let Some(range) = RangeDetacher::new(meta.table_name.as_ref(), column_id, arena) + .detach(&op.predicate)? { apply_column_count += 1; @@ -374,7 +387,11 @@ impl PushPredicateIntoScan { pub struct PushJoinPredicateIntoScan; impl NormalizationRule for PushJoinPredicateIntoScan { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { let (join_type, filter_expr) = { let Operator::Join(join_op) = &mut plan.operator else { return Ok(false); @@ -394,20 +411,24 @@ impl NormalizationRule for PushJoinPredicateIntoScan { (join_op.join_type, filter_expr) }; - let left_columns = left_child(plan) - .map(plan_output_columns) - .unwrap_or_default(); - let right_columns = right_child(plan) - .map(plan_output_columns) - .unwrap_or_default(); - let filter_exprs = split_conjunctive_predicates(&filter_expr); + let left_columns: &[crate::catalog::ColumnRef] = match plan.childrens.as_mut() { + Childrens::Only(left) => left.output_schema(arena), + Childrens::Twins { left, .. } => left.output_schema(arena), + Childrens::None => &EMPTY_SCHEMA, + }; let (left_filters, rest): (Vec<_>, Vec<_>) = filter_exprs.into_iter().partition(|expr| { - expr.all_referenced_columns(true, |column| left_columns.contains(column)) + expr.all_referenced_columns(arena, |_, column| left_columns.contains(column)) }); + let left_len = left_columns.len(); + + let right_columns: &[crate::catalog::ColumnRef] = match plan.childrens.as_mut() { + Childrens::Twins { right, .. } => right.output_schema(arena), + _ => &EMPTY_SCHEMA, + }; let (right_filters, common_filters): (Vec<_>, Vec<_>) = rest.into_iter().partition(|expr| { - expr.all_referenced_columns(true, |column| right_columns.contains(column)) + expr.all_referenced_columns(arena, |_, column| right_columns.contains(column)) }); let (push_left, push_right) = match join_type { @@ -436,7 +457,7 @@ impl NormalizationRule for PushJoinPredicateIntoScan { } else { (Vec::new(), right_filters) }; - localize_right_filters(&mut right_push, left_columns.len())?; + localize_right_filters(&mut right_push, left_len)?; if let Some(filter_op) = reduce_filters(right_push, false) { new_ops.1 = Some(Operator::Filter(filter_op)); } else { @@ -476,8 +497,9 @@ impl NormalizationRule for PushJoinPredicateIntoScan { #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { + use crate::binder::test::build_t1_table; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, TableName}; + use crate::catalog::{ColumnCatalog, ColumnDesc, TableName}; use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::expression::{BinaryOperator, ScalarExpression}; @@ -490,30 +512,28 @@ mod tests { use crate::planner::operator::join::{JoinCondition, JoinType}; use crate::planner::operator::table_scan::TableScanOperator; use crate::planner::operator::{Operator, SortOption}; - use crate::planner::{Childrens, LogicalPlan}; - use crate::storage::rocksdb::RocksTransaction; + use crate::planner::{Childrens, LogicalPlan, PlanArena}; use crate::types::index::{IndexInfo, IndexLookup, IndexMeta, IndexType}; use crate::types::value::DataValue; use crate::types::LogicalType; use std::collections::Bound; - use std::sync::Arc; use ulid::Ulid; fn apply_pipeline( plan: LogicalPlan, builder: HepOptimizerPipelineBuilder, + arena: &mut PlanArena, ) -> Result { - builder - .build() - .instantiate(plan) - .find_best::(None) + builder.build().instantiate(plan).find_best(None, arena) } #[test] fn test_push_predicate_into_scan() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let mut arena = PlanArena::new(&table_state.table_arena); // 1 - c2 < 0 => c2 > 1 - let plan = table_state.plan("select * from t1 where -(1 - c2) > 0")?; + let plan = + table_state.plan_with_arena("select * from t1 where -(1 - c2) > 0", &mut arena)?; let best_plan = apply_pipeline( plan, @@ -528,6 +548,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::PushPredicateIntoScan], ), + &mut arena, )?; let scan_op = best_plan.childrens.pop_only().childrens.pop_only(); @@ -551,6 +572,8 @@ mod tests { #[test] fn test_cover_mapping_matches_scan_order() -> Result<(), DatabaseError> { let table_name: TableName = ::std::sync::Arc::from("mock_table"); + let table_arena = crate::planner::TableArenaCell::default(); + let mut arena = PlanArena::new(&table_arena); let c1_id = Ulid::new(); let c2_id = Ulid::new(); let c3_id = Ulid::new(); @@ -561,7 +584,7 @@ mod tests { ColumnDesc::new(LogicalType::Integer, Some(0), false, None)?, ); c1.set_ref_table(table_name.clone(), c1_id, false); - let c1_ref = ColumnRef::from(c1.clone()); + let c1_ref = arena.alloc_column(c1); let mut c2 = ColumnCatalog::new( "c2".to_string(), @@ -569,7 +592,7 @@ mod tests { ColumnDesc::new(LogicalType::Integer, None, false, None)?, ); c2.set_ref_table(table_name.clone(), c2_id, false); - let c2_ref = ColumnRef::from(c2.clone()); + let c2_ref = arena.alloc_column(c2); let mut c3 = ColumnCatalog::new( "c3".to_string(), @@ -578,9 +601,9 @@ mod tests { ); c3.set_ref_table(table_name.clone(), c3_id, false); - let columns = vec![c1_ref.clone(), c2_ref.clone()]; + let columns = vec![c1_ref, c2_ref]; - let index_meta_reordered = Arc::new(IndexMeta { + let index_meta_reordered = arena.alloc_index(IndexMeta { id: 0, column_ids: vec![c2_id, c3_id, c1_id], table_name: table_name.clone(), @@ -593,7 +616,7 @@ mod tests { name: "idx_c2_c3_c1".to_string(), ty: IndexType::Composite, }); - let index_meta_aligned = Arc::new(IndexMeta { + let index_meta_aligned = arena.alloc_index(IndexMeta { id: 1, column_ids: vec![c1_id, c2_id], table_name: table_name.clone(), @@ -641,14 +664,14 @@ mod tests { let c1_gt = ScalarExpression::Binary { op: BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::column_expr(c1_ref.clone(), 0)), + left_expr: Box::new(ScalarExpression::column_expr(c1_ref, 0)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(0))), evaluator: None, ty: LogicalType::Boolean, }; let c2_gt = ScalarExpression::Binary { op: BinaryOperator::Gt, - left_expr: Box::new(ScalarExpression::column_expr(c2_ref.clone(), 1)), + left_expr: Box::new(ScalarExpression::column_expr(c2_ref, 1)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(0))), evaluator: None, ty: LogicalType::Boolean, @@ -677,6 +700,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::PushPredicateIntoScan], ), + &mut arena, )?; let table_scan = best_plan.childrens.pop_only(); @@ -693,7 +717,7 @@ mod tests { assert_eq!(deserializers.len(), 3); assert_eq!( deserializers[0], - c2_ref.datatype().serializable(), + arena.column(c2_ref).datatype().serializable(), "first serializer should align with c2" ); assert_eq!( @@ -703,7 +727,7 @@ mod tests { ); assert_eq!( deserializers[2], - c1_ref.datatype().serializable(), + arena.column(c1_ref).datatype().serializable(), "last serializer should align with c1" ); let mapping = reordered_index.cover_mapping.as_deref(); @@ -726,8 +750,11 @@ mod tests { #[test] fn test_push_predicate_through_join_in_left_join() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = - table_state.plan("select * from t1 left join t2 on c1 = c3 where c1 > 1 and c3 < 2")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select * from t1 left join t2 on c1 = c3 where c1 > 1 and c3 < 2", + &mut arena, + )?; let best_plan = apply_pipeline( plan, @@ -736,6 +763,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::PushPredicateThroughJoin], ), + &mut arena, )?; let filter_op = best_plan.childrens.pop_only(); @@ -772,8 +800,11 @@ mod tests { #[test] fn test_push_predicate_through_join_in_right_join() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state - .plan("select * from t1 right join t2 on c1 = c3 where c1 > 1 and c3 < 2")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select * from t1 right join t2 on c1 = c3 where c1 > 1 and c3 < 2", + &mut arena, + )?; let best_plan = apply_pipeline( plan, @@ -782,6 +813,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::PushPredicateThroughJoin], ), + &mut arena, )?; let filter_op = best_plan.childrens.pop_only(); @@ -818,8 +850,11 @@ mod tests { #[test] fn test_push_predicate_through_join_in_inner_join() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state - .plan("select * from t1 inner join t2 on c1 = c3 where c1 > 1 and c3 < 2")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select * from t1 inner join t2 on c1 = c3 where c1 > 1 and c3 < 2", + &mut arena, + )?; let best_plan = apply_pipeline( plan, @@ -828,6 +863,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::PushPredicateThroughJoin], ), + &mut arena, )?; let join_op = best_plan.childrens.pop_only(); @@ -869,8 +905,11 @@ mod tests { #[test] fn test_push_join_predicate_into_scan_inner_join() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state - .plan("select * from t1 inner join t2 on t1.c1 = t2.c3 and t1.c1 > 1 and t2.c3 < 2")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select * from t1 inner join t2 on t1.c1 = t2.c3 and t1.c1 > 1 and t2.c3 < 2", + &mut arena, + )?; let mut best_plan = apply_pipeline( plan, @@ -879,6 +918,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::PushJoinPredicateIntoScan], ), + &mut arena, )?; if matches!(best_plan.operator, Operator::Project(_)) { @@ -941,8 +981,11 @@ mod tests { #[test] fn test_push_join_predicate_left_outer_preserve_left() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = - table_state.plan("select * from t1 left join t2 on t1.c1 = t2.c3 and t1.c1 > 1")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select * from t1 left join t2 on t1.c1 = t2.c3 and t1.c1 > 1", + &mut arena, + )?; let mut best_plan = apply_pipeline( plan, @@ -951,6 +994,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::PushJoinPredicateIntoScan], ), + &mut arena, )?; if matches!(best_plan.operator, Operator::Project(_)) { @@ -985,8 +1029,11 @@ mod tests { #[test] fn test_push_join_predicate_left_outer_push_right() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = - table_state.plan("select * from t1 left join t2 on t1.c1 = t2.c3 and t2.c3 < 2")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select * from t1 left join t2 on t1.c1 = t2.c3 and t2.c3 < 2", + &mut arena, + )?; let mut best_plan = apply_pipeline( plan, @@ -995,6 +1042,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::PushJoinPredicateIntoScan], ), + &mut arena, )?; if matches!(best_plan.operator, Operator::Project(_)) { diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index 720802c9..158d54ed 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -23,37 +23,41 @@ use crate::planner::{Childrens, LogicalPlan}; #[derive(Copy, Clone)] pub struct ConstantCalculation; -pub(crate) fn constant_calculation_current(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { +pub(crate) fn constant_calculation_current( + plan: &mut LogicalPlan, + arena: &crate::planner::PlanArena, +) -> Result<(), DatabaseError> { let operator = &mut plan.operator; + let mut calculator = ConstantCalculator::new(arena); match operator { Operator::Aggregate(op) => { for expr in op.agg_calls.iter_mut().chain(op.groupby_exprs.iter_mut()) { - ConstantCalculator.visit(expr)?; + calculator.visit(expr)?; } } Operator::Filter(op) => { - ConstantCalculator.visit(&mut op.predicate)?; + calculator.visit(&mut op.predicate)?; } Operator::Join(op) => { if let JoinCondition::On { on, filter } = &mut op.on { for (left_expr, right_expr) in on { - ConstantCalculator.visit(left_expr)?; - ConstantCalculator.visit(right_expr)?; + calculator.visit(left_expr)?; + calculator.visit(right_expr)?; } if let Some(expr) = filter { - ConstantCalculator.visit(expr)?; + calculator.visit(expr)?; } } } Operator::Project(op) => { for expr in &mut op.exprs { - ConstantCalculator.visit(expr)?; + calculator.visit(expr)?; } } Operator::Sort(op) => { for field in &mut op.sort_fields { - ConstantCalculator.visit(&mut field.expr)?; + calculator.visit(&mut field.expr)?; } } _ => (), @@ -63,13 +67,16 @@ pub(crate) fn constant_calculation_current(plan: &mut LogicalPlan) -> Result<(), } impl ConstantCalculation { - fn _apply(plan: &mut LogicalPlan) -> Result<(), DatabaseError> { - constant_calculation_current(plan)?; + fn _apply( + plan: &mut LogicalPlan, + arena: &crate::planner::PlanArena, + ) -> Result<(), DatabaseError> { + constant_calculation_current(plan, arena)?; match plan.childrens.as_mut() { - Childrens::Only(child) => Self::_apply(child.as_mut())?, + Childrens::Only(child) => Self::_apply(child.as_mut(), arena)?, Childrens::Twins { left, right } => { - Self::_apply(left.as_mut())?; - Self::_apply(right.as_mut())?; + Self::_apply(left.as_mut(), arena)?; + Self::_apply(right.as_mut(), arena)?; } Childrens::None => (), } @@ -79,8 +86,12 @@ impl ConstantCalculation { } impl NormalizationRule for ConstantCalculation { - fn apply(&self, plan: &mut LogicalPlan) -> Result { - Self::_apply(plan)?; + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { + Self::_apply(plan, arena)?; Ok(true) } } @@ -103,7 +114,11 @@ fn has_aggregate_descendant(plan: &LogicalPlan) -> bool { } impl NormalizationRule for SimplifyFilter { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + arena: &mut crate::planner::PlanArena, + ) -> Result { if let Operator::Filter(filter_op) = &mut plan.operator { if filter_op.is_optimized { return Ok(false); @@ -113,7 +128,7 @@ impl NormalizationRule for SimplifyFilter { return Ok(false); } } - ConstantCalculator.visit(&mut filter_op.predicate)?; + ConstantCalculator::new(arena).visit(&mut filter_op.predicate)?; Simplify::default().visit(&mut filter_op.predicate)?; filter_op.is_optimized = true; return Ok(true); @@ -126,7 +141,7 @@ impl NormalizationRule for SimplifyFilter { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use crate::binder::test::build_t1_table; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; + use crate::errors::DatabaseError; use crate::expression::range_detacher::{Range, RangeDetacher}; use crate::expression::{BinaryOperator, ScalarExpression, UnaryOperator}; @@ -134,8 +149,7 @@ mod test { use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; use crate::optimizer::rule::normalization::NormalizationRuleImpl; use crate::planner::operator::Operator; - use crate::planner::LogicalPlan; - use crate::storage::rocksdb::RocksTransaction; + use crate::planner::{LogicalPlan, PlanArena}; use crate::types::value::DataValue; use crate::types::{ColumnId, LogicalType}; use std::collections::Bound; @@ -145,20 +159,24 @@ mod test { name: &str, strategy: HepBatchStrategy, rules: Vec, + arena: &mut PlanArena, ) -> Result { HepOptimizerPipeline::builder() .before_batch(name.to_string(), strategy, rules) .build() .instantiate(plan) - .find_best::(None) + .find_best(None, arena) } #[test] fn test_constant_calculation_omitted() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let mut arena = PlanArena::new(&table_state.table_arena); // (2 + (-1)) < -(c1 + 1) - let plan = - table_state.plan("select c1 + (2 + 1), 2 + 1 from t1 where (2 + (-1)) < -(c1 + 1)")?; + let plan = table_state.plan_with_arena( + "select c1 + (2 + 1), 2 + 1 from t1 where (2 + (-1)) < -(c1 + 1)", + &mut arena, + )?; let best_plan = HepOptimizerPipeline::builder() .before_batch( @@ -171,7 +189,7 @@ mod test { ) .build() .instantiate(plan) - .find_best::(None)?; + .find_best(None, &mut arena)?; if let Operator::Project(project_op) = best_plan.clone().operator { let constant_expr = ScalarExpression::Constant(DataValue::Int32(3)); if let ScalarExpression::Binary { right_expr, .. } = &project_op.exprs[0] { @@ -185,7 +203,7 @@ mod test { } let filter_op = best_plan.childrens.pop_only(); if let Operator::Filter(filter_op) = filter_op.operator { - let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + let range = RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &arena) .detach(&filter_op.predicate)? .unwrap(); assert_eq!( @@ -205,14 +223,18 @@ mod test { #[test] fn test_constant_cast_elimination() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state - .plan("select cast(1 as int), cast(2 as int) + 1 from t1 where cast(3 as int) = 3")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = table_state.plan_with_arena( + "select cast(1 as int), cast(2 as int) + 1 from t1 where cast(3 as int) = 3", + &mut arena, + )?; let best_plan = run_with_single_batch( plan, "test_constant_cast_elimination", HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::ConstantCalculation], + &mut arena, )?; if let Operator::Project(project_op) = best_plan.operator { @@ -244,40 +266,50 @@ mod test { #[test] fn test_simplify_filter_single_column() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let mut arena = PlanArena::new(&table_state.table_arena); // c1 + 1 < -1 => c1 < -2 - let plan_1 = table_state.plan("select * from t1 where -(c1 + 1) > 1")?; + let plan_1 = + table_state.plan_with_arena("select * from t1 where -(c1 + 1) > 1", &mut arena)?; // 1 - c1 < -1 => c1 > 2 - let plan_2 = table_state.plan("select * from t1 where -(1 - c1) > 1")?; + let plan_2 = + table_state.plan_with_arena("select * from t1 where -(1 - c1) > 1", &mut arena)?; // c1 < -1 - let plan_3 = table_state.plan("select * from t1 where -c1 > 1")?; + let plan_3 = table_state.plan_with_arena("select * from t1 where -c1 > 1", &mut arena)?; // c1 > 0 - let plan_4 = table_state.plan("select * from t1 where c1 + 1 > 1")?; + let plan_4 = + table_state.plan_with_arena("select * from t1 where c1 + 1 > 1", &mut arena)?; // c1 + 1 < -1 => c1 < -2 - let plan_5 = table_state.plan("select * from t1 where 1 < -(c1 + 1)")?; + let plan_5 = + table_state.plan_with_arena("select * from t1 where 1 < -(c1 + 1)", &mut arena)?; // 1 - c1 < -1 => c1 > 2 - let plan_6 = table_state.plan("select * from t1 where 1 < -(1 - c1)")?; + let plan_6 = + table_state.plan_with_arena("select * from t1 where 1 < -(1 - c1)", &mut arena)?; // c1 < -1 - let plan_7 = table_state.plan("select * from t1 where 1 < -c1")?; + let plan_7 = table_state.plan_with_arena("select * from t1 where 1 < -c1", &mut arena)?; // c1 > 0 - let plan_8 = table_state.plan("select * from t1 where 1 < c1 + 1")?; + let plan_8 = + table_state.plan_with_arena("select * from t1 where 1 < c1 + 1", &mut arena)?; // c1 < 24 - let plan_9 = table_state.plan("select * from t1 where (-1 - c1) + 1 > 24")?; + let plan_9 = + table_state.plan_with_arena("select * from t1 where (-1 - c1) + 1 > 24", &mut arena)?; // c1 < 24 - let plan_10 = table_state.plan("select * from t1 where 24 < (-1 - c1) + 1")?; + let plan_10 = + table_state.plan_with_arena("select * from t1 where 24 < (-1 - c1) + 1", &mut arena)?; - let op = |plan: LogicalPlan| -> Result, DatabaseError> { + let mut op = |plan: LogicalPlan| -> Result, DatabaseError> { let best_plan = run_with_single_batch( plan, "test_simplify_filter", HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::SimplifyFilter], + &mut arena, )?; let filter_op = best_plan.childrens.pop_only(); if let Operator::Filter(filter_op) = filter_op.operator { Ok( - RangeDetacher::new("t1", table_state.column_id_by_name("c1")) + RangeDetacher::new("t1", table_state.column_id_by_name("c1"), &arena) .detach(&filter_op.predicate)?, ) } else { @@ -309,6 +341,7 @@ mod test { #[test] fn test_simplify_filter_boolean_wrapped_range_comparison() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let mut arena = PlanArena::new(&table_state.table_arena); let expected = Some(Range::Scope { min: Bound::Unbounded, max: Bound::Included(DataValue::Int32(10)), @@ -319,9 +352,9 @@ mod test { "select * from t1 where (c1 > 10) != true", "select * from t1 where not (c1 > 10)", ] { - let plan = table_state.plan(sql)?; + let plan = table_state.plan_with_arena(sql, &mut arena)?; assert_eq!( - plan_filter(&plan, table_state.column_id_by_name("c1"))?, + plan_filter(&plan, table_state.column_id_by_name("c1"), &mut arena)?, expected ); } @@ -332,43 +365,22 @@ mod test { #[test] fn test_simplify_filter_repeating_column() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan = table_state.plan("select * from t1 where -(c1 + 1) > c2")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan = + table_state.plan_with_arena("select * from t1 where -(c1 + 1) > c2", &mut arena)?; let best_plan = run_with_single_batch( plan, "test_simplify_filter", HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::SimplifyFilter], + &mut arena, )?; let filter_op = best_plan.childrens.pop_only(); if let Operator::Filter(filter_op) = filter_op.operator { - let c1_col = ColumnCatalog::direct_new( - ColumnSummary { - name: "c1".to_string(), - relation: ColumnRelation::Table { - column_id: *table_state.column_id_by_name("c1"), - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None)?, - false, - ); - let c2_col = ColumnCatalog::direct_new( - ColumnSummary { - name: "c2".to_string(), - relation: ColumnRelation::Table { - column_id: *table_state.column_id_by_name("c2"), - table_name: "t1".to_string().into(), - is_temp: false, - }, - }, - false, - ColumnDesc::new(LogicalType::Integer, None, true, None)?, - false, - ); + let c1_ref = table_state.table.get_column_by_name("c1").unwrap(); + let c2_ref = table_state.table.get_column_by_name("c2").unwrap(); // -(c1 + 1) > c2 => c1 < -c2 - 1 assert_eq!( @@ -379,10 +391,7 @@ mod test { op: UnaryOperator::Minus, expr: Box::new(ScalarExpression::Binary { op: BinaryOperator::Plus, - left_expr: Box::new(ScalarExpression::column_expr( - ColumnRef::from(c1_col), - 0 - )), + left_expr: Box::new(ScalarExpression::column_expr(c1_ref, 0)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Integer, @@ -390,9 +399,7 @@ mod test { evaluator: None, ty: LogicalType::Integer, }), - right_expr: Box::new( - ScalarExpression::column_expr(ColumnRef::from(c2_col), 1,) - ), + right_expr: Box::new(ScalarExpression::column_expr(c2_ref, 1)), evaluator: None, ty: LogicalType::Boolean, } @@ -407,17 +414,19 @@ mod test { fn plan_filter( plan: &LogicalPlan, column_id: &ColumnId, + arena: &mut PlanArena, ) -> Result, DatabaseError> { let best_plan = run_with_single_batch( plan.clone(), "test_simplify_filter", HepBatchStrategy::once_topdown(), vec![NormalizationRuleImpl::SimplifyFilter], + arena, )?; let filter_op = best_plan.childrens.pop_only(); if let Operator::Filter(filter_op) = filter_op.operator { - Ok(RangeDetacher::new("t1", column_id).detach(&filter_op.predicate)?) + Ok(RangeDetacher::new("t1", column_id, arena).detach(&filter_op.predicate)?) } else { Ok(None) } @@ -426,26 +435,43 @@ mod test { #[test] fn test_simplify_filter_multiple_column() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let mut arena = PlanArena::new(&table_state.table_arena); // c1 + 1 < -1 => c1 < -2 - let plan_1 = table_state.plan("select * from t1 where -(c1 + 1) > 1 and -(1 - c2) > 1")?; + let plan_1 = table_state.plan_with_arena( + "select * from t1 where -(c1 + 1) > 1 and -(1 - c2) > 1", + &mut arena, + )?; // 1 - c1 < -1 => c1 > 2 - let plan_2 = table_state.plan("select * from t1 where -(1 - c1) > 1 and -(c2 + 1) > 1")?; + let plan_2 = table_state.plan_with_arena( + "select * from t1 where -(1 - c1) > 1 and -(c2 + 1) > 1", + &mut arena, + )?; // c1 < -1 - let plan_3 = table_state.plan("select * from t1 where -c1 > 1 and c2 + 1 > 1")?; + let plan_3 = table_state + .plan_with_arena("select * from t1 where -c1 > 1 and c2 + 1 > 1", &mut arena)?; // c1 > 0 - let plan_4 = table_state.plan("select * from t1 where c1 + 1 > 1 and -c2 > 1")?; + let plan_4 = table_state + .plan_with_arena("select * from t1 where c1 + 1 > 1 and -c2 > 1", &mut arena)?; - let range_1_c1 = plan_filter(&plan_1, table_state.column_id_by_name("c1"))?.unwrap(); - let range_1_c2 = plan_filter(&plan_1, table_state.column_id_by_name("c2"))?.unwrap(); + let range_1_c1 = + plan_filter(&plan_1, table_state.column_id_by_name("c1"), &mut arena)?.unwrap(); + let range_1_c2 = + plan_filter(&plan_1, table_state.column_id_by_name("c2"), &mut arena)?.unwrap(); - let range_2_c1 = plan_filter(&plan_2, table_state.column_id_by_name("c1"))?.unwrap(); - let range_2_c2 = plan_filter(&plan_2, table_state.column_id_by_name("c2"))?.unwrap(); + let range_2_c1 = + plan_filter(&plan_2, table_state.column_id_by_name("c1"), &mut arena)?.unwrap(); + let range_2_c2 = + plan_filter(&plan_2, table_state.column_id_by_name("c2"), &mut arena)?.unwrap(); - let range_3_c1 = plan_filter(&plan_3, table_state.column_id_by_name("c1"))?.unwrap(); - let range_3_c2 = plan_filter(&plan_3, table_state.column_id_by_name("c2"))?.unwrap(); + let range_3_c1 = + plan_filter(&plan_3, table_state.column_id_by_name("c1"), &mut arena)?.unwrap(); + let range_3_c2 = + plan_filter(&plan_3, table_state.column_id_by_name("c2"), &mut arena)?.unwrap(); - let range_4_c1 = plan_filter(&plan_4, table_state.column_id_by_name("c1"))?.unwrap(); - let range_4_c2 = plan_filter(&plan_4, table_state.column_id_by_name("c2"))?.unwrap(); + let range_4_c1 = + plan_filter(&plan_4, table_state.column_id_by_name("c1"), &mut arena)?.unwrap(); + let range_4_c2 = + plan_filter(&plan_4, table_state.column_id_by_name("c2"), &mut arena)?.unwrap(); assert_eq!( range_1_c1, @@ -510,11 +536,13 @@ mod test { #[test] fn test_simplify_filter_multiple_column_in_or() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; + let mut arena = PlanArena::new(&table_state.table_arena); // c1 > c2 or c1 > 1 - let plan_1 = table_state.plan("select * from t1 where c1 > c2 or c1 > 1")?; + let plan_1 = + table_state.plan_with_arena("select * from t1 where c1 > c2 or c1 > 1", &mut arena)?; assert_eq!( - plan_filter(&plan_1, table_state.column_id_by_name("c1"))?, + plan_filter(&plan_1, table_state.column_id_by_name("c1"), &mut arena)?, None ); @@ -524,10 +552,14 @@ mod test { #[test] fn test_simplify_filter_multiple_dispersed_same_column_in_or() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan_1 = table_state.plan("select * from t1 where c1 = 4 and c1 > c2 or c1 > 1")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan_1 = table_state.plan_with_arena( + "select * from t1 where c1 = 4 and c1 > c2 or c1 > 1", + &mut arena, + )?; assert_eq!( - plan_filter(&plan_1, table_state.column_id_by_name("c1"))?, + plan_filter(&plan_1, table_state.column_id_by_name("c1"), &mut arena)?, Some(Range::Scope { min: Bound::Excluded(DataValue::Int32(1)), max: Bound::Unbounded, @@ -540,10 +572,12 @@ mod test { #[test] fn test_simplify_filter_column_is_null() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan_1 = table_state.plan("select * from t1 where c1 is null")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan_1 = + table_state.plan_with_arena("select * from t1 where c1 is null", &mut arena)?; assert_eq!( - plan_filter(&plan_1, table_state.column_id_by_name("c1"))?, + plan_filter(&plan_1, table_state.column_id_by_name("c1"), &mut arena)?, Some(Range::Eq(DataValue::Null)) ); @@ -553,10 +587,12 @@ mod test { #[test] fn test_simplify_filter_column_is_not_null() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan_1 = table_state.plan("select * from t1 where c1 is not null")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan_1 = + table_state.plan_with_arena("select * from t1 where c1 is not null", &mut arena)?; assert_eq!( - plan_filter(&plan_1, table_state.column_id_by_name("c1"))?, + plan_filter(&plan_1, table_state.column_id_by_name("c1"), &mut arena)?, None ); @@ -566,10 +602,12 @@ mod test { #[test] fn test_simplify_filter_column_in() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan_1 = table_state.plan("select * from t1 where c1 in (1, 2, 3)")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan_1 = + table_state.plan_with_arena("select * from t1 where c1 in (1, 2, 3)", &mut arena)?; assert_eq!( - plan_filter(&plan_1, table_state.column_id_by_name("c1"))?, + plan_filter(&plan_1, table_state.column_id_by_name("c1"), &mut arena)?, Some(Range::SortedRanges(vec![ Range::Eq(DataValue::Int32(1)), Range::Eq(DataValue::Int32(2)), @@ -583,10 +621,12 @@ mod test { #[test] fn test_simplify_filter_column_not_in() -> Result<(), DatabaseError> { let table_state = build_t1_table()?; - let plan_1 = table_state.plan("select * from t1 where c1 not in (1, 2, 3)")?; + let mut arena = PlanArena::new(&table_state.table_arena); + let plan_1 = table_state + .plan_with_arena("select * from t1 where c1 not in (1, 2, 3)", &mut arena)?; assert_eq!( - plan_filter(&plan_1, table_state.column_id_by_name("c1"))?, + plan_filter(&plan_1, table_state.column_id_by_name("c1"), &mut arena)?, None ); diff --git a/src/optimizer/rule/normalization/top_k.rs b/src/optimizer/rule/normalization/top_k.rs index 193ec959..73cd74a1 100644 --- a/src/optimizer/rule/normalization/top_k.rs +++ b/src/optimizer/rule/normalization/top_k.rs @@ -22,7 +22,11 @@ use crate::planner::LogicalPlan; pub struct TopK; impl NormalizationRule for TopK { - fn apply(&self, plan: &mut LogicalPlan) -> Result { + fn apply( + &self, + plan: &mut LogicalPlan, + _: &mut crate::planner::PlanArena, + ) -> Result { let (offset, limit) = match &plan.operator { Operator::Limit(op) => match op.limit { Some(limit) => (op.offset, limit), diff --git a/src/orm/README.md b/src/orm/README.md index dc44a011..bfff6424 100644 --- a/src/orm/README.md +++ b/src/orm/README.md @@ -1,31 +1,19 @@ # ORM -KiteSQL provides a built-in ORM behind `features = ["orm"]`. - -The ORM is centered around `#[derive(Model)]`. It generates: - -- tuple-to-struct mapping -- cached model statements -- cached DDL statements -- migration metadata -- typed field accessors for query building - -## Enabling the feature +KiteSQL's ORM is available with `features = ["orm"]`. This also enables the +derive macros used by `#[derive(Model)]`; `#[derive(Projection)]` is optional +for DTO-style projections. ```toml kite_sql = { version = "*", features = ["orm"] } ``` -If you also want to derive the model macro, enable `macros` as well: +## Model -```toml -kite_sql = { version = "*", features = ["orm", "macros"] } -``` +`#[derive(Model)]` defines the table mapping, cached model operations, typed +field accessors, and migration metadata. -## Quick start - -```rust -use kite_sql::db::DataBaseBuilder; +```rust,ignore use kite_sql::Model; #[derive(Default, Debug, PartialEq, Model)] @@ -39,210 +27,170 @@ struct User { #[model(default = "18", index)] age: Option, } +``` -let database = DataBaseBuilder::path(".").build_in_memory()?; -database.create_table::()?; +Common field attributes are `primary_key`, `unique`, `index`, `rename`, +`default`, `varchar`, `char`, `decimal_precision`, `decimal_scale`, and `skip`. + +## Queries +The recommended query entrypoint is `bind`. The closure receives an +`OrmContext`, drives the binder directly, and returns a `LogicalPlan`. Query +chains end with `.finish()`; mutation chains end with `update`, `delete`, or +another operation that already returns a plan. + +```rust,ignore +use kite_sql::db::{DataBaseBuilder, ResultIter}; +use kite_sql::orm::OrmQueryResultExt; + +let mut database = DataBaseBuilder::path(".").build_in_memory()?; +database.create_table::()?; database.insert(&User { id: 1, name: "Alice".to_string(), age: Some(18), })?; -let user = database.get::(&1)?.unwrap(); -assert_eq!(user.name, "Alice"); - let adults = database - .from::() - .gte(User::age(), 18) - .asc(User::name()) - .fetch()?; - -for user in adults { - println!("{:?}", user?); -} + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let adult = e.column(User::age())?.gte(18)?; + let named_a = e.column(User::name())?.like("A%")?; + adult.and(named_a) + })? + .finish() + })? + .orm::() + .collect::, _>>()?; + +assert_eq!(adults[0].name, "Alice"); # Ok::<(), Box>(()) ``` -## Model derive - -`#[derive(Model)]` is the intended entry point for ORM models. - -Struct attributes: - -- `#[model(table = "users")]`: sets the backing table name -- `#[model(index(name = "idx", columns = "a, b"))]`: declares a secondary index at the model level - -Field attributes: - -- `#[model(primary_key)]` -- `#[model(unique)]` -- `#[model(index)]` -- `#[model(rename = "column_name")]` -- `#[model(default = "18")]` -- `#[model(varchar = 64)]` -- `#[model(char = 2)]` -- `#[model(decimal_precision = 10, decimal_scale = 2)]` -- `#[model(skip)]` - -The derive macro generates the `Model` implementation, tuple decoding, cached -read/insert/DDL statements, migration metadata, and typed field getters such as -`User::id()` and `User::name()`. - -## Query Builder - -`Database::from::()` and `DBTransaction::from::()` start a typed query -from one ORM model table. - -The usual flow is: - -- start with `from::()` -- add filters, joins, grouping, ordering, and limits -- keep full-model output, or switch into `project::

()`, - `project_value(...)`, or `project_tuple(...)` -- once the output shape is fixed, compose set queries with `union(...)`, - `except(...)`, `intersect(...)`, and optional `.all()` - -If you need an explicit relation alias, call `.alias("name")` on a source or -pending join, and re-qualify fields with `Field::qualify("name")` where -needed. For ordinary multi-table queries, `inner_join::().on(...)`, -`left_join::().on(...)`, `right_join::().on(...)`, -`full_join::().on(...)`, `cross_join::()`, and `using(...)` cover most -cases. - -Most expression building starts from generated fields such as `User::id()` and -`User::name()`. Field values support arithmetic, comparison, null checks, -pattern matching, range checks, casts, aliases, and subquery predicates. For -computed expressions, use `QueryValue` helpers such as `func`, `count`, -`count_all`, `sum`, `avg`, `min`, `max`, `case_when`, and `case_value`. - -Boolean composition lives on `QueryExpr` through `and`, `or`, `not`, `exists`, -and `not_exists`. - -### Projections - -Use full-model fetches when the query still matches `M`, or switch to one of -the projection modes: - -- `project::

()`: decode rows into a DTO-style struct -- `project_value(...)`: decode one expression per row into a scalar type -- `project_tuple(...)`: decode multiple expressions positionally into a tuple - -For `project::

()`, `P` is typically a `#[derive(Projection)]` type whose -field names match the output names. Use `#[projection(rename = "...")]` to map -DTO fields to differently named source columns, and `#[projection(from = "...")]` -for join projections that need an explicit source relation. - -If the output is expression-based, prefer `project_value(...)` or -`project_tuple(...)` and assign explicit names with `.alias(...)`. - -### Set queries - -Set operations are available after the output shape is fixed: - -- model rows: `from::().union(...)` -- single values: `project_value(...).union(...)` -- tuples: `project_tuple(...).except(...)` -- intersections: `project_value(...).intersect(...)` -- struct projections: `project::

().union(...)` +Inside expression closures, `e.column(User::id())?` resolves through the core +binder and returns a bound expression. Expression methods such as `eq`, `gte`, +`like`, `and`, `or`, `is_null`, and `in_list` compose directly into core +`ScalarExpression` values; constants can be passed directly. + +Prefer the compact helpers when the query shape is simple: + +```rust,ignore +let rows = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::age())?.gte(18))? + .order_by(User::age().desc())? + .project_scalars((User::id(), User::name()))? + .finish() + })? + .project_tuple::<(i32, String)>() + .collect::, _>>()?; +# let _ = rows; +# Ok::<(), Box>(()) +``` -Call `.all()` after `union(...)`, `except(...)`, or `intersect(...)` when you want multiset -semantics instead of the default distinct result. +Use `project_scalar(...)` for one field, `project_scalars((...))` for simple +tuples, and `project_value/project_tuple` when the projection is an expression +or needs aliases. Use `order_by` for field ordering and already-bound sort +fields, or `order_by_expr` for computed sort expressions. Call `.asc()`, +`.desc()`, `.nulls_first()`, or `.nulls_last()` when the default +ascending/nulls-last order is not enough. + +Joins and set operations use the same binder-backed style: + +```rust,ignore +let joined = database + .bind(|ctx| { + ctx.from::()? + .inner_join::(|e| { + e.column(User::id())?.eq(e.column(Order::user_id())?) + })? + .project_scalars((User::name(), Order::amount()))? + .order_by(Order::id())? + .finish() + })?; + +let ids = database.bind(|ctx| { + ctx.union( + true, + |ctx| ctx.from::()?.project_scalar(User::id())?.finish(), + |ctx| ctx.from::()?.project_scalar(Order::user_id())?.finish(), + ) +})?; +# let _ = (joined, ids); +# Ok::<(), Box>(()) +``` -After a set query is formed, you can still apply result-level methods such as -`asc(...)`, `desc(...)`, `nulls_first()`, `nulls_last()`, `limit(...)`, -`offset(...)`, `fetch()`, `get()`, `exists()`, `count()`, and `explain()`. +## Writes -Tips: `nulls_first()` and `nulls_last()` only affect the most recently added -sort key from `asc(...)` or `desc(...)`. +For model rows, use the direct helpers: -For richer combinations such as join projections, grouping, scalar subqueries, -and set queries, prefer the rustdoc on `FromBuilder`, `SetQueryBuilder`, -`Field`, `QueryValue`, and `QueryExpr`. +```rust,ignore +database.insert(&user)?; +database.insert_many(users)?; +let user = database.get::(&1)?; +let all = database.fetch::()?; +# let _ = (user, all); +# Ok::<(), Box>(()) +``` -## Change Operations +For query-shaped writes, start with `ctx.mutate::()` and finish with +`update` or `delete`. + +```rust,ignore +database + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(1))? + .update(|u| { + u.set_value(User::name(), "Bob")?; + u.set_value(User::age(), None::) + }) + })? + .done()?; + +database + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(2))? + .delete() + })? + .done()?; +# Ok::<(), Box>(()) +``` -The ORM supports both schema changes and data changes. +`insert_select` and `overwrite_select` accept the same closure style for the +source plan: + +```rust,ignore +database + .bind(|ctx| { + ctx.insert_select::(["id", "user_name"], |ctx| { + ctx.from::()? + .project_scalars((User::id(), User::name()))? + .finish() + }) + })? + .done()?; +# Ok::<(), Box>(()) +``` -### Schema changes +## Schema And Maintenance -On `Database`: +Common schema helpers are: - `create_table::()` - `create_table_if_not_exists::()` - `migrate::()` -- `drop_index::(index_name)` -- `drop_index_if_exists::(index_name)` - `drop_table::()` - `drop_table_if_exists::()` - `truncate::()` -- `create_view(name, query_builder)` -- `create_or_replace_view(name, query_builder)` -- `drop_view(name)` -- `drop_view_if_exists(name)` - -`DBTransaction` does not currently expose the ORM DDL convenience methods. - -Typical schema maintenance uses the same model types and query builders: -create tables from `Model`, truncate by model, and create or replace views from -ORM queries. - -### Data changes - -For common model-oriented writes: - -- `insert::(&model)` -- `insert_many::(models)` - -For query-driven writes, reuse the same filtered `from::()` entrypoint and -finish with: - -- `insert::()` -- `insert_into::(...)` -- `overwrite::()` -- `overwrite_into::(...)` -- `update().set(...).execute()` -- `delete()` - -Here `overwrite*` follows the engine's `INSERT OVERWRITE` semantics, meaning -conflicting target rows are replaced rather than the whole table being cleared. - -For model-oriented writes, use `insert` and `insert_many`. For query-driven -writes, compose from `from::()` and finish with `insert`, `overwrite`, -`update`, or `delete`. - -Query-driven writes are intentionally shaped like read queries first, so the -same filters, joins, and projections can flow into the final write operation. - -## Introspection / Maintenance - -The ORM also exposes light-weight introspection and maintenance helpers. - -On `Database`: - -- `show_tables()` -- `show_views()` -- `describe::()` -- `from::()...explain()` -- `analyze::()` - -On `DBTransaction`: - -- `show_tables()` -- `show_views()` -- `describe::()` - -These helpers are intended for light-weight inspection around ORM-managed -tables, without dropping down to raw SQL for common metadata queries. - -## Further reading +- `create_view(...)` / `create_or_replace_view(...)` +- `drop_view(...)` / `drop_view_if_exists(...)` -Detailed method-by-method examples live in the rustdoc for: +Introspection helpers include `show_tables()`, `show_views()`, `describe::()`, +and `analyze::()`. -- `Field` -- `QueryValue` -- `QueryExpr` -- `FromBuilder` -- `SetQueryBuilder` -- `Projection` -- `Model` +The ORM frontend does not build SQL AST nodes. SQL parsing is the SQL frontend; +ORM queries bind directly into `ScalarExpression` and `LogicalPlan`. diff --git a/src/orm/ddl.rs b/src/orm/ddl.rs index 3a914475..5cafff4e 100644 --- a/src/orm/ddl.rs +++ b/src/orm/ddl.rs @@ -31,11 +31,11 @@ impl Database { /// let database = DataBaseBuilder::path(".").build_in_memory().unwrap(); /// database.create_table::().unwrap(); /// ``` - pub fn create_table(&self) -> Result<(), DatabaseError> { - self.execute(M::create_table_statement(), &[])?.done()?; + pub fn create_table(&mut self) -> Result<(), DatabaseError> { + execute_create_table::<_, M>(self, false)?; - for statement in M::create_index_statements() { - self.execute(statement, &[])?.done()?; + for index in M::indexes() { + execute_create_index(self, M::table_name(), index, false)?; } Ok(()) @@ -46,12 +46,11 @@ impl Database { /// This is useful for examples, tests and bootstrap flows where rerunning /// schema initialization should stay idempotent. Secondary indexes declared /// with `#[model(index)]` are created with `IF NOT EXISTS` as well. - pub fn create_table_if_not_exists(&self) -> Result<(), DatabaseError> { - self.execute(M::create_table_if_not_exists_statement(), &[])? - .done()?; + pub fn create_table_if_not_exists(&mut self) -> Result<(), DatabaseError> { + execute_create_table::<_, M>(self, true)?; - for statement in M::create_index_if_not_exists_statements() { - self.execute(statement, &[])?.done()?; + for index in M::indexes() { + execute_create_index(self, M::table_name(), index, true)?; } Ok(()) @@ -78,46 +77,55 @@ impl Database { /// assert_eq!(database.fetch::().unwrap().count(), 0); /// ``` pub fn truncate(&self) -> Result<(), DatabaseError> { - self.execute(&orm_truncate_statement(M::table_name()), &[])? - .done() + self.bind(|ctx| ctx.truncate::())?.done() } - /// Creates a view from an ORM query builder. - pub fn create_view( - &self, - view_name: &str, - query: Q, - ) -> Result<(), DatabaseError> { - self.execute( - &orm_create_view_statement(view_name, query.into_subquery(), false), - &[], - )? - .done() + /// Creates a view from a binder-backed ORM plan builder. + pub fn create_view(&mut self, view_name: &str, build: F) -> Result<(), DatabaseError> + where + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + S::TransactionType<'_>, + &'static [(&'static str, DataValue)], + >, + ) -> Result, + { + execute_create_view(self, view_name, build, false) } - /// Creates or replaces a view from an ORM query builder. - pub fn create_or_replace_view( - &self, + /// Creates or replaces a view from a binder-backed ORM plan builder. + pub fn create_or_replace_view( + &mut self, view_name: &str, - query: Q, - ) -> Result<(), DatabaseError> { - self.execute( - &orm_create_view_statement(view_name, query.into_subquery(), true), - &[], - )? - .done() + build: F, + ) -> Result<(), DatabaseError> + where + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + S::TransactionType<'_>, + &'static [(&'static str, DataValue)], + >, + ) -> Result, + { + execute_create_view(self, view_name, build, true) } /// Drops a view by name. - pub fn drop_view(&self, view_name: &str) -> Result<(), DatabaseError> { - self.execute(&orm_drop_view_statement(view_name, false), &[])? - .done() + pub fn drop_view(&mut self, view_name: &str) -> Result<(), DatabaseError> { + execute_drop_view(self, view_name, false) } /// Drops a view by name if it exists. - pub fn drop_view_if_exists(&self, view_name: &str) -> Result<(), DatabaseError> { - self.execute(&orm_drop_view_statement(view_name, true), &[])? - .done() + pub fn drop_view_if_exists(&mut self, view_name: &str) -> Result<(), DatabaseError> { + execute_drop_view(self, view_name, true) } /// Migrates an existing table to match the current model definition. @@ -131,7 +139,7 @@ impl Database { /// to type, nullability and default expressions for non-primary-key columns /// when the underlying DDL supports them. Primary-key changes and unique /// constraint changes still return an error so you can handle them manually. - pub fn migrate(&self) -> Result<(), DatabaseError> { + pub fn migrate(&mut self) -> Result<(), DatabaseError> { let columns = M::columns(); if columns.is_empty() { return Err(DatabaseError::UnsupportedStmt( @@ -143,131 +151,171 @@ impl Database { let Some(table) = self.table_catalog(M::table_name())? else { return self.create_table::(); }; + let (table_primary_key, current_columns) = { + let table_arena = self.state.table_arena().borrow(); + let table_primary_key = table + .primary_keys() + .first() + .map(|(_, column)| table_arena.column(*column).clone()) + .ok_or(DatabaseError::PrimaryKeyNotFound)?; + let current_columns = table + .columns() + .map(|column| { + let column = table_arena.column(*column).clone(); + (column.name().to_string(), column) + }) + .collect::>(); + (table_primary_key, current_columns) + }; let model_primary_key = columns .iter() - .find(|column| column.primary_key) - .ok_or(DatabaseError::PrimaryKeyNotFound)?; - let table_primary_key = table - .primary_keys() - .first() - .map(|(_, column)| column.clone()) + .find(|column| column.desc().is_primary()) .ok_or(DatabaseError::PrimaryKeyNotFound)?; - if table_primary_key.name() != model_primary_key.name - || !model_column_matches_catalog(model_primary_key, &table_primary_key) + if table_primary_key.name() != model_primary_key.name() + || !model_column_matches_catalog(model_primary_key, &table_primary_key)? { return Err(DatabaseError::InvalidValue(::std::format!( "ORM migration does not support changing the primary key for table `{}`", M::table_name(), ))); } - - let current_columns = table - .columns() - .map(|column| (column.name().to_string(), column.clone())) - .collect::>(); let model_columns = columns .iter() - .map(|column| (column.name, column)) + .map(|column| (column.name(), column)) .collect::>(); let mut handled_current = BTreeMap::new(); let mut handled_model = BTreeMap::new(); for column in columns { - let Some(current_column) = current_columns.get(column.name) else { + let Some(current_column) = current_columns.get(column.name()) else { continue; }; handled_current.insert(current_column.name().to_string(), ()); - handled_model.insert(column.name, ()); + handled_model.insert(column.name(), ()); - if column.primary_key != current_column.desc().is_primary() { + if column.desc().is_primary() != current_column.desc().is_primary() { return Err(DatabaseError::InvalidValue(::std::format!( "ORM migration does not support changing the primary key for table `{}`", M::table_name(), ))); } - if column.unique != current_column.desc().is_unique() { + if column.desc().is_unique() != current_column.desc().is_unique() { return Err(DatabaseError::InvalidValue(::std::format!( "ORM migration cannot automatically change unique constraint on column `{}` of table `{}`", - column.name, + column.name(), M::table_name(), ))); } - if model_column_matches_catalog(column, current_column) { + if model_column_matches_catalog(column, current_column)? { continue; } if !model_column_type_matches_catalog(column, current_column) { - let statement = orm_alter_column_type_statement( + execute_change_column( + self, M::table_name(), - column.name, - &column.ddl_type, + column.name(), + column.name(), + column.datatype().clone(), + DefaultChange::NoChange, + NotNullChange::NoChange, )?; - self.execute(&statement, &[])?.done()?; } - if model_column_default(column) != catalog_column_default(current_column) { - let statement = orm_alter_column_default_statement( + if model_column_default(column)? != catalog_column_default(current_column)? { + execute_change_column( + self, M::table_name(), - column.name, - column.default_expr, + column.name(), + column.name(), + column.datatype().clone(), + match column.desc().default.clone() { + Some(expr) => DefaultChange::Set(expr), + None => DefaultChange::Drop, + }, + NotNullChange::NoChange, )?; - self.execute(&statement, &[])?.done()?; } - if column.nullable != current_column.nullable() { - let statement = orm_alter_column_nullability_statement( + if column.nullable() != current_column.nullable() { + execute_change_column( + self, M::table_name(), - column.name, - column.nullable, - ); - self.execute(&statement, &[])?.done()?; + column.name(), + column.name(), + column.datatype().clone(), + DefaultChange::NoChange, + if column.nullable() { + NotNullChange::Drop + } else { + NotNullChange::Set + }, + )?; } } let mut rename_pairs = Vec::new(); let unmatched_model_columns = columns .iter() - .filter(|column| !handled_model.contains_key(column.name)) + .filter(|column| !handled_model.contains_key(column.name())) .collect::>(); - let unmatched_current_columns = table - .columns() + let unmatched_current_columns = current_columns + .values() .filter(|column| !handled_current.contains_key(column.name())) + .cloned() .collect::>(); for model_column in &unmatched_model_columns { - if model_column.primary_key { + if model_column.desc().is_primary() { continue; } - let candidates = unmatched_current_columns + let mut candidates = Vec::new(); + for column in unmatched_current_columns .iter() - .copied() .filter(|column| !column.desc().is_primary()) - .filter(|column| model_column_rename_compatible(model_column, column)) - .collect::>(); + { + if model_column_rename_compatible(model_column, column)? { + candidates.push(column); + } + } if candidates.len() != 1 { continue; } let current_column = candidates[0]; - let reverse_candidates = unmatched_model_columns + let mut reverse_candidates = Vec::new(); + for other in unmatched_model_columns .iter() - .filter(|other| !other.primary_key) - .filter(|other| model_column_rename_compatible(other, current_column)) - .collect::>(); + .filter(|other| !other.desc().is_primary()) + { + if model_column_rename_compatible(other, current_column)? { + reverse_candidates.push(other); + } + } if reverse_candidates.len() != 1 { continue; } - rename_pairs.push((current_column.name().to_string(), model_column.name)); + rename_pairs.push((current_column.name().to_string(), model_column.name())); handled_current.insert(current_column.name().to_string(), ()); - handled_model.insert(model_column.name, ()); + handled_model.insert(model_column.name(), ()); } for (old_name, new_name) in rename_pairs { - let statement = orm_rename_column_statement(M::table_name(), &old_name, new_name); - self.execute(&statement, &[])?.done()?; + let current_column = current_columns + .get(&old_name) + .ok_or_else(|| DatabaseError::column_not_found(old_name.clone()))?; + execute_change_column( + self, + M::table_name(), + &old_name, + new_name, + current_column.datatype().clone(), + DefaultChange::NoChange, + NotNullChange::NoChange, + )?; } - for column in table.columns() { + for column in current_columns.values() { if handled_current.contains_key(column.name()) || model_columns.contains_key(column.name()) { @@ -281,29 +329,28 @@ impl Database { ))); } - let statement = orm_drop_column_statement(M::table_name(), column.name()); - self.execute(&statement, &[])?.done()?; + execute_drop_column(self, M::table_name(), column.name())?; } for column in columns { - if handled_model.contains_key(column.name) || current_columns.contains_key(column.name) + if handled_model.contains_key(column.name()) + || current_columns.contains_key(column.name()) { continue; } - if column.primary_key { + if column.desc().is_primary() { return Err(DatabaseError::InvalidValue(::std::format!( "ORM migration cannot add a new primary key column `{}` to an existing table `{}`", - column.name, + column.name(), M::table_name(), ))); } - let statement = orm_add_column_statement(M::table_name(), column)?; - self.execute(&statement, &[])?.done()?; + execute_add_column(self, M::table_name(), column)?; } - for statement in M::create_index_if_not_exists_statements() { - self.execute(statement, &[])?.done()?; + for index in M::indexes() { + execute_create_index(self, M::table_name(), index, true)?; } Ok(()) @@ -313,17 +360,16 @@ impl Database { /// /// Primary-key indexes are managed by the table definition itself and /// cannot be dropped independently. - pub fn drop_index(&self, index_name: &str) -> Result<(), DatabaseError> { - let statement = orm_drop_index_statement(M::table_name(), index_name, false); - - self.execute(&statement, &[])?.done() + pub fn drop_index(&mut self, index_name: &str) -> Result<(), DatabaseError> { + execute_drop_index(self, M::table_name(), index_name, false) } /// Drops a non-primary-key model index by name if it exists. - pub fn drop_index_if_exists(&self, index_name: &str) -> Result<(), DatabaseError> { - let statement = orm_drop_index_statement(M::table_name(), index_name, true); - - self.execute(&statement, &[])?.done() + pub fn drop_index_if_exists( + &mut self, + index_name: &str, + ) -> Result<(), DatabaseError> { + execute_drop_index(self, M::table_name(), index_name, true) } /// Drops the model table. @@ -346,16 +392,169 @@ impl Database { /// database.create_table::().unwrap(); /// database.drop_table::().unwrap(); /// ``` - pub fn drop_table(&self) -> Result<(), DatabaseError> { - self.execute(M::drop_table_statement(), &[])?.done() + pub fn drop_table(&mut self) -> Result<(), DatabaseError> { + execute_drop_table(self, M::table_name(), false) } /// Drops the model table if it exists. /// /// This variant is convenient for cleanup code that should succeed even if /// the table was already removed. - pub fn drop_table_if_exists(&self) -> Result<(), DatabaseError> { - self.execute(M::drop_table_if_exists_statement(), &[])? - .done() + pub fn drop_table_if_exists(&mut self) -> Result<(), DatabaseError> { + execute_drop_table(self, M::table_name(), true) } } + +fn execute_create_table( + database: &mut Database, + if_not_exists: bool, +) -> Result<(), DatabaseError> { + let columns = M::columns().to_vec(); + database.execute_mut("ORM CREATE TABLE", &[], move |binder, _| { + binder.bind_create_table(M::table_name().into(), columns, if_not_exists) + }) +} + +fn execute_create_index( + database: &mut Database, + table_name: &'static str, + index: &(&'static str, &'static [&'static str], bool), + if_not_exists: bool, +) -> Result<(), DatabaseError> { + let (index_name, index_columns, unique) = *index; + let index_name = index_name.to_string(); + let column_names = index_columns.to_vec(); + database.execute_mut("ORM CREATE INDEX", &[], move |binder, arena| { + let mut input = binder.bind_create_index_source(table_name.into(), arena)?; + let schema = input.output_schema(arena).clone(); + let mut columns = Vec::with_capacity(column_names.len()); + for column_name in column_names { + let column = schema + .iter() + .copied() + .find(|column| arena.column(*column).name() == column_name) + .ok_or_else(|| DatabaseError::column_not_found(column_name.to_string()))?; + columns.push(column); + } + binder.bind_create_index( + table_name.into(), + index_name, + columns, + if_not_exists, + unique, + input, + ) + }) +} + +fn execute_create_view( + database: &mut Database, + view_name: &str, + build: F, + or_replace: bool, +) -> Result<(), DatabaseError> +where + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + S::TransactionType<'_>, + &'static [(&'static str, DataValue)], + >, + ) -> Result, +{ + static EMPTY_ORM_PARAMS: &[(&str, DataValue)] = &[]; + let view_name = view_name.to_string(); + database.execute_mut("ORM CREATE VIEW", EMPTY_ORM_PARAMS, move |binder, arena| { + let mut context = OrmContext { binder, arena }; + let plan = build(&mut context)?; + binder.bind_create_view( + view_name.as_str().into(), + or_replace, + plan, + vec![], + vec![], + arena, + ) + }) +} + +fn execute_drop_view( + database: &mut Database, + view_name: &str, + if_exists: bool, +) -> Result<(), DatabaseError> { + let view_name = view_name.to_string(); + database.execute_mut("ORM DROP VIEW", &[], move |binder, _| { + binder.bind_drop_view(view_name.as_str().into(), if_exists) + }) +} + +fn execute_change_column( + database: &mut Database, + table_name: &'static str, + old_column_name: &str, + new_column_name: &str, + data_type: LogicalType, + default_change: DefaultChange, + not_null_change: NotNullChange, +) -> Result<(), DatabaseError> { + let old_column_name = old_column_name.to_string(); + let new_column_name = new_column_name.to_string(); + database.execute_mut("ORM CHANGE COLUMN", &[], move |binder, _| { + binder.bind_change_column( + table_name.into(), + old_column_name, + new_column_name, + data_type, + default_change, + not_null_change, + ) + }) +} + +fn execute_drop_column( + database: &mut Database, + table_name: &'static str, + column_name: &str, +) -> Result<(), DatabaseError> { + let column_name = column_name.to_string(); + database.execute_mut("ORM DROP COLUMN", &[], move |binder, _| { + binder.bind_drop_column(table_name.into(), column_name, false) + }) +} + +fn execute_add_column( + database: &mut Database, + table_name: &'static str, + column: &ColumnCatalog, +) -> Result<(), DatabaseError> { + let column = column.clone(); + database.execute_mut("ORM ADD COLUMN", &[], move |binder, _| { + binder.bind_add_column(table_name.into(), column, false) + }) +} + +fn execute_drop_index( + database: &mut Database, + table_name: &'static str, + index_name: &str, + if_exists: bool, +) -> Result<(), DatabaseError> { + let index_name = index_name.to_string(); + database.execute_mut("ORM DROP INDEX", &[], move |binder, _| { + binder.bind_drop_index(table_name.into(), index_name, if_exists) + }) +} + +fn execute_drop_table( + database: &mut Database, + table_name: &'static str, + if_exists: bool, +) -> Result<(), DatabaseError> { + database.execute_mut("ORM DROP TABLE", &[], move |binder, _| { + binder.bind_drop_table(table_name.into(), if_exists) + }) +} diff --git a/src/orm/dml.rs b/src/orm/dml.rs index 10e5c709..4adb5cae 100644 --- a/src/orm/dml.rs +++ b/src/orm/dml.rs @@ -6,8 +6,8 @@ impl Database { /// /// This runs `ANALYZE TABLE` for the backing table so the optimizer can use /// up-to-date statistics. - pub fn analyze(&self) -> Result<(), DatabaseError> { - orm_analyze::<_, M>(self) + pub fn analyze_model(&mut self) -> Result<(), DatabaseError> { + self.analyze(M::table_name()) } /// Inserts a model into its backing table. diff --git a/src/orm/dql.rs b/src/orm/dql.rs index 4edf5b87..0a8e1865 100644 --- a/src/orm/dql.rs +++ b/src/orm/dql.rs @@ -1,6 +1,40 @@ use super::*; impl Database { + /// Executes a binder-backed plan built inside a closure. + pub fn bind(&self, build: F) -> Result, DatabaseError> + where + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + S::TransactionType<'_>, + &'static [(&'static str, DataValue)], + >, + ) -> Result, + { + bind_orm_context(self, build) + } + + /// Explains a binder-backed plan built inside a closure. + pub fn explain(&self, build: F) -> Result + where + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + S::TransactionType<'_>, + &'static [(&'static str, DataValue)], + >, + ) -> Result, + { + explain_orm_context(self, build) + } + /// Loads a single model by primary key. /// /// The key type is taken from `M::PrimaryKey`, so `database.get::(&1)` @@ -58,29 +92,6 @@ impl Database { orm_list::<_, M>(self) } - /// Starts a typed single-table query builder for the given model. - /// - /// ```rust - /// use kite_sql::db::DataBaseBuilder; - /// use kite_sql::Model; - /// - /// #[derive(Default, Debug, PartialEq, Model)] - /// #[model(table = "users")] - /// struct User { - /// #[model(primary_key)] - /// id: i32, - /// name: String, - /// } - /// - /// let database = DataBaseBuilder::path(".").build_in_memory().unwrap(); - /// database.create_table::().unwrap(); - /// let count = database.from::().count().unwrap(); - /// assert_eq!(count, 0); - /// ``` - pub fn from(&self) -> FromBuilder<&Database, M> { - FromBuilder::from_inner(QueryBuilder::new(self)) - } - /// Lists all table names. /// /// ```rust @@ -104,7 +115,7 @@ impl Database { &self, ) -> Result, String>, DatabaseError> { Ok(ProjectValueIter::new( - self.execute(&orm_show_tables_statement(), &[])?, + self.bind(|ctx| ctx.binder.bind_show_tables())?, )) } @@ -113,7 +124,7 @@ impl Database { &self, ) -> Result, String>, DatabaseError> { Ok(ProjectValueIter::new( - self.execute(&orm_show_views_statement(), &[])?, + self.bind(|ctx| ctx.binder.bind_show_views())?, )) } @@ -140,12 +151,49 @@ impl Database { &self, ) -> Result, DescribeColumn>, DatabaseError> { Ok(self - .execute(&orm_describe_statement(M::table_name()), &[])? + .bind(|ctx| ctx.binder.bind_describe(M::table_name().into()))? .orm::()) } } impl<'a, S: Storage> DBTransaction<'a, S> { + /// Executes a binder-backed plan inside the current transaction. + pub fn bind( + &mut self, + build: F, + ) -> Result>, DatabaseError> + where + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + S::TransactionType<'a>, + &'static [(&'static str, DataValue)], + >, + ) -> Result, + { + bind_orm_context(self, build) + } + + /// Explains a binder-backed plan inside the current transaction. + pub fn explain(&mut self, build: F) -> Result + where + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + S::TransactionType<'a>, + &'static [(&'static str, DataValue)], + >, + ) -> Result, + { + explain_orm_context(self, build) + } + /// Loads a single model by primary key inside the current transaction. pub fn get(&mut self, key: &M::PrimaryKey) -> Result, DatabaseError> { orm_get::<_, M>(self, key) @@ -158,18 +206,13 @@ impl<'a, S: Storage> DBTransaction<'a, S> { orm_list::<_, M>(self) } - /// Starts a typed single-table query builder inside the current transaction. - pub fn from(&mut self) -> FromBuilder<&mut DBTransaction<'a, S>, M> { - FromBuilder::from_inner(QueryBuilder::new(self)) - } - /// Lists all table names inside the current transaction. pub fn show_tables( &mut self, ) -> Result>, String>, DatabaseError> { Ok(ProjectValueIter::new( - self.execute(&orm_show_tables_statement(), &[])?, + self.bind(|ctx| ctx.binder.bind_show_tables())?, )) } @@ -179,7 +222,7 @@ impl<'a, S: Storage> DBTransaction<'a, S> { ) -> Result>, String>, DatabaseError> { Ok(ProjectValueIter::new( - self.execute(&orm_show_views_statement(), &[])?, + self.bind(|ctx| ctx.binder.bind_show_views())?, )) } @@ -189,7 +232,7 @@ impl<'a, S: Storage> DBTransaction<'a, S> { ) -> Result>, DescribeColumn>, DatabaseError> { Ok(self - .execute(&orm_describe_statement(M::table_name()), &[])? + .bind(|ctx| ctx.binder.bind_describe(M::table_name().into()))? .orm::()) } } diff --git a/src/orm/mod.rs b/src/orm/mod.rs index 23f61fae..5c4c17d5 100644 --- a/src/orm/mod.rs +++ b/src/orm/mod.rs @@ -1,37 +1,35 @@ #![doc = include_str!("README.md")] -use crate::catalog::{ColumnRef, TableCatalog}; +use crate::binder::{ + with_query_bind_step, BindPlanFrom, BindPlanSelectList, Binder, JoinConstraintInput, + QueryBindStep, SetOperatorKind, TableAliasInput, +}; +use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName}; use crate::db::{ - BorrowResultIter, DBTransaction, Database, DatabaseIter, OrmIter, ResultIter, Statement, + BindSource, BorrowResultIter, DBTransaction, Database, DatabaseIter, OrmIter, ResultIter, TransactionIter, }; use crate::errors::DatabaseError; +use crate::expression::{self, AliasType, ScalarExpression}; +use crate::planner::operator::alter_table::change_column::{DefaultChange, NotNullChange}; +use crate::planner::operator::join::JoinType; +use crate::planner::operator::mark_apply::MarkApplyQuantifier; +use crate::planner::operator::sort::SortField; +use crate::planner::{LogicalPlan, PlanArena}; use crate::storage::{Storage, Transaction}; -use crate::types::tuple::{SchemaRef, Tuple}; +use crate::types::tuple::{SchemaView, Tuple}; use crate::types::value::DataValue; use crate::types::CharLengthUnits; use crate::types::LogicalType; -use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +#[cfg(feature = "decimal")] use rust_decimal::Decimal; -use sqlparser::ast::helpers::attached_token::AttachedToken; -use sqlparser::ast::TruncateTableTarget; -use sqlparser::ast::{ - AlterColumnOperation, AlterTable, AlterTableOperation, Analyze, Assignment, AssignmentTarget, - BinaryOperator as SqlBinaryOperator, CaseWhen, CastKind, ColumnDef, ColumnOption, - ColumnOptionDef, CreateIndex, CreateTable, CreateTableOptions, CreateView, DataType, Delete, - DescribeAlias, Distinct, Expr, FromTable, Function, FunctionArg, FunctionArgExpr, - FunctionArgumentList, FunctionArguments, GroupByExpr, HiveDistributionStyle, Ident, - IndexColumn, Insert, Join, JoinConstraint, JoinOperator, KeyOrIndexDisplay, LimitClause, - NullsDistinctOption, ObjectName, ObjectType, Offset, OffsetRows, OrderBy, OrderByExpr, - OrderByKind, OrderByOptions, PrimaryKeyConstraint, Query, Select, SelectFlavor, SelectItem, - SetExpr, SetOperator, SetQuantifier, ShowStatementOptions, TableAlias, TableFactor, - TableObject, TableWithJoins, TimezoneInfo, Truncate, UniqueConstraint, Update, Value, Values, - ViewColumnDef, -}; -use sqlparser::dialect::PostgreSqlDialect; -use sqlparser::parser::Parser; +use std::borrow::Cow; use std::collections::BTreeMap; +use std::fmt; +use std::hash::{Hash, Hasher}; use std::marker::PhantomData; +use std::ptr::NonNull; +use std::rc::Rc; use std::sync::Arc; mod ddl; @@ -45,25 +43,12 @@ mod dql; #[doc(hidden)] pub struct OrmField { pub column: &'static str, + pub column_index: usize, pub placeholder: &'static str, pub primary_key: bool, pub unique: bool, } -#[derive(Debug, Clone, PartialEq, Eq)] -/// Static metadata about a single persisted model column. -/// -/// This is primarily consumed by the built-in ORM migration helper. -#[doc(hidden)] -pub struct OrmColumn { - pub name: &'static str, - pub ddl_type: String, - pub nullable: bool, - pub primary_key: bool, - pub unique: bool, - pub default_expr: Option<&'static str>, -} - /// One row returned by [`Database::describe`] or [`DBTransaction::describe`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct DescribeColumn { @@ -75,8 +60,8 @@ pub struct DescribeColumn { pub default: String, } -impl From<(&SchemaRef, Tuple)> for DescribeColumn { - fn from((_, tuple): (&SchemaRef, Tuple)) -> Self { +impl From<(&SchemaView<'_, '_>, Tuple)> for DescribeColumn { + fn from((_, tuple): (&SchemaView<'_, '_>, Tuple)) -> Self { let mut values = tuple.values.into_iter(); let field = describe_text_value(values.next()); @@ -100,52 +85,6 @@ impl From<(&SchemaRef, Tuple)> for DescribeColumn { } } -impl OrmColumn { - fn column_def(&self) -> Result { - let mut options = Vec::new(); - - if self.primary_key { - options.push(column_option(ColumnOption::PrimaryKey( - PrimaryKeyConstraint { - name: None, - index_name: None, - index_type: None, - columns: vec![], - index_options: vec![], - characteristics: None, - }, - ))); - } else { - if !self.nullable { - options.push(column_option(ColumnOption::NotNull)); - } - if self.unique { - options.push(column_option(ColumnOption::Unique(UniqueConstraint { - name: None, - index_name: None, - index_type_display: KeyOrIndexDisplay::None, - index_type: None, - columns: vec![], - index_options: vec![], - characteristics: None, - nulls_distinct: NullsDistinctOption::None, - }))); - } - } - if let Some(default_expr) = self.default_expr { - options.push(column_option(ColumnOption::Default(parse_expr_fragment( - default_expr, - )?))); - } - - Ok(ColumnDef { - name: ident(self.name), - data_type: parse_data_type_fragment(&self.ddl_type)?, - options, - }) - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] /// Typed column handle generated for `#[derive(Model)]` query builders. /// @@ -157,28 +96,20 @@ pub struct Field { _marker: PhantomData<(M, T)>, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[doc(hidden)] +pub struct FieldSort { + field: Field, + asc: bool, + nulls_first: bool, +} + #[derive(Debug, Clone, PartialEq, Eq)] struct QuerySource { table_name: String, alias: Option, } -#[derive(Debug, Clone, PartialEq)] -struct JoinSpec { - source: QuerySource, - kind: JoinKind, - constraint: JoinConstraint, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum JoinKind { - Inner, - Left, - Right, - Full, - Cross, -} - impl QuerySource { fn model() -> Self { Self { @@ -187,4764 +118,2340 @@ impl QuerySource { } } - fn relation_name(&self) -> &str { - self.alias.as_deref().unwrap_or(&self.table_name) - } - fn with_alias(mut self, alias: impl Into) -> Self { self.alias = Some(alias.into()); self } } -impl JoinSpec { - fn into_ast(self) -> Join { - let join_operator = match self.kind { - JoinKind::Inner => JoinOperator::Inner(self.constraint), - JoinKind::Left => JoinOperator::Left(self.constraint), - JoinKind::Right => JoinOperator::Right(self.constraint), - JoinKind::Full => JoinOperator::FullOuter(self.constraint), - JoinKind::Cross => JoinOperator::CrossJoin(self.constraint), - }; - - Join { - relation: source_table_factor(&self.source), - global: false, - join_operator, +impl Field { + #[doc(hidden)] + pub const fn new(table: &'static str, column: &'static str) -> Self { + Self { + table, + column, + _marker: PhantomData, } } -} - -trait ValueExpressionOps: Sized { - fn into_query_value(self) -> QueryValue; - fn binary_value_expr>(self, op: SqlBinaryOperator, value: V) -> QueryValue { - QueryValue::from_expr(Expr::BinaryOp { - left: Box::new(self.into_query_value().into_expr()), - op, - right: Box::new(value.into().into_expr()), - }) + pub fn table_name(&self) -> &'static str { + self.table } - fn unary_value_expr(self, op: sqlparser::ast::UnaryOperator) -> QueryValue { - QueryValue::from_expr(Expr::UnaryOp { - op, - expr: Box::new(self.into_query_value().into_expr()), - }) + pub fn column_name(&self) -> &'static str { + self.column } - fn add_expr>(self, value: V) -> QueryValue { - self.binary_value_expr(SqlBinaryOperator::Plus, value) + pub fn asc(self) -> FieldSort { + FieldSort::new(self).asc() } - fn sub_expr>(self, value: V) -> QueryValue { - self.binary_value_expr(SqlBinaryOperator::Minus, value) + pub fn desc(self) -> FieldSort { + FieldSort::new(self).desc() } - fn mul_expr>(self, value: V) -> QueryValue { - self.binary_value_expr(SqlBinaryOperator::Multiply, value) + pub fn nulls_first(self) -> FieldSort { + FieldSort::new(self).nulls_first() } - fn div_expr>(self, value: V) -> QueryValue { - self.binary_value_expr(SqlBinaryOperator::Divide, value) + pub fn nulls_last(self) -> FieldSort { + FieldSort::new(self).nulls_last() } +} - fn modulo_expr>(self, value: V) -> QueryValue { - self.binary_value_expr(SqlBinaryOperator::Modulo, value) +impl FieldSort { + fn new(field: Field) -> Self { + Self { + field, + asc: true, + nulls_first: false, + } } - fn neg_expr(self) -> QueryValue { - self.unary_value_expr(sqlparser::ast::UnaryOperator::Minus) + pub fn asc(mut self) -> Self { + self.asc = true; + self } - fn eq_expr>(self, value: V) -> QueryExpr { - QueryExpr::from_expr(Expr::BinaryOp { - left: Box::new(self.into_query_value().into_expr()), - op: CompareOp::Eq.as_ast(), - right: Box::new(value.into().into_expr()), - }) + pub fn desc(mut self) -> Self { + self.asc = false; + self } - fn ne_expr>(self, value: V) -> QueryExpr { - QueryExpr::from_expr(Expr::BinaryOp { - left: Box::new(self.into_query_value().into_expr()), - op: CompareOp::Ne.as_ast(), - right: Box::new(value.into().into_expr()), - }) + pub fn nulls_first(mut self) -> Self { + self.nulls_first = true; + self } - fn gt_expr>(self, value: V) -> QueryExpr { - QueryExpr::from_expr(Expr::BinaryOp { - left: Box::new(self.into_query_value().into_expr()), - op: CompareOp::Gt.as_ast(), - right: Box::new(value.into().into_expr()), - }) + pub fn nulls_last(mut self) -> Self { + self.nulls_first = false; + self } +} - fn gte_expr>(self, value: V) -> QueryExpr { - QueryExpr::from_expr(Expr::BinaryOp { - left: Box::new(self.into_query_value().into_expr()), - op: CompareOp::Gte.as_ast(), - right: Box::new(value.into().into_expr()), - }) - } +#[doc(hidden)] +pub trait BindOrmScalar<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_scalar( + self, + scope: &mut ExprBindScope<'_, 'bind, 'parent, 'arena, T, A>, + ) -> Result; +} - fn lt_expr>(self, value: V) -> QueryExpr { - QueryExpr::from_expr(Expr::BinaryOp { - left: Box::new(self.into_query_value().into_expr()), - op: CompareOp::Lt.as_ast(), - right: Box::new(value.into().into_expr()), - }) +impl<'bind, 'parent, 'arena, T, A, M, V> BindOrmScalar<'bind, 'parent, 'arena, T, A> for Field +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_scalar( + self, + scope: &mut ExprBindScope<'_, 'bind, 'parent, 'arena, T, A>, + ) -> Result { + scope.column(self).map(CtxExpression::into_scalar) } +} - fn lte_expr>(self, value: V) -> QueryExpr { - QueryExpr::from_expr(Expr::BinaryOp { - left: Box::new(self.into_query_value().into_expr()), - op: CompareOp::Lte.as_ast(), - right: Box::new(value.into().into_expr()), - }) +impl<'bind, 'parent, 'arena, T, A> BindOrmScalar<'bind, 'parent, 'arena, T, A> for ScalarExpression +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_scalar( + self, + _scope: &mut ExprBindScope<'_, 'bind, 'parent, 'arena, T, A>, + ) -> Result { + Ok(self) } +} - fn quantified_subquery_expr( +impl<'bind, 'parent, 'arena, T, A> BindOrmScalar<'bind, 'parent, 'arena, T, A> + for CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_scalar( self, - compare_op: CompareOp, - quantifier: QuantifiedSubquery, - subquery: S, - ) -> QueryExpr { - let left = self.into_query_value().into_expr(); - let right = Expr::Subquery(Box::new(subquery.into_subquery())); - QueryExpr::from_expr(quantifier.into_ast(left, compare_op.as_ast(), right)) + _scope: &mut ExprBindScope<'_, 'bind, 'parent, 'arena, T, A>, + ) -> Result { + Ok(self.into_scalar()) } +} - #[allow(clippy::wrong_self_convention)] - fn is_null_expr(self) -> QueryExpr { - QueryExpr::from_expr(Expr::IsNull(Box::new(self.into_query_value().into_expr()))) - } +#[doc(hidden)] +pub trait BindOrmSort<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_sort<'scope>( + self, + scope: &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result; +} - #[allow(clippy::wrong_self_convention)] - fn is_not_null_expr(self) -> QueryExpr { - QueryExpr::from_expr(Expr::IsNotNull(Box::new( - self.into_query_value().into_expr(), - ))) +impl<'bind, 'parent, 'arena, T, A, M, V> BindOrmSort<'bind, 'parent, 'arena, T, A> for Field +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_sort<'scope>( + self, + scope: &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result { + self.bind_scalar(scope).map(SortField::from) } +} - fn like_expr>(self, pattern: V) -> QueryExpr { - QueryExpr::from_expr(Expr::Like { - negated: false, - expr: Box::new(self.into_query_value().into_expr()), - pattern: Box::new(pattern.into().into_expr()), - escape_char: None, - any: false, - }) +impl<'bind, 'parent, 'arena, T, A, M, V> BindOrmSort<'bind, 'parent, 'arena, T, A> + for FieldSort +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_sort<'scope>( + self, + scope: &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result { + let mut sort = self.field.bind_scalar(scope).map(SortField::from)?; + sort.asc = self.asc; + sort.nulls_first = self.nulls_first; + Ok(sort) } +} - fn not_like_expr>(self, pattern: V) -> QueryExpr { - QueryExpr::from_expr(Expr::Like { - negated: true, - expr: Box::new(self.into_query_value().into_expr()), - pattern: Box::new(pattern.into().into_expr()), - escape_char: None, - any: false, - }) +impl<'bind, 'parent, 'arena, T, A> BindOrmSort<'bind, 'parent, 'arena, T, A> for ScalarExpression +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_sort<'scope>( + self, + _scope: &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result { + Ok(self.into()) } +} - fn in_list_expr(self, values: I) -> QueryExpr - where - I: IntoIterator, - V: Into, - { - QueryExpr::from_expr(Expr::InList { - expr: Box::new(self.into_query_value().into_expr()), - list: values - .into_iter() - .map(Into::into) - .map(QueryValue::into_expr) - .collect(), - negated: false, - }) +impl<'bind, 'parent, 'arena, T, A> BindOrmSort<'bind, 'parent, 'arena, T, A> for SortField +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_sort<'scope>( + self, + _scope: &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result { + Ok(self) } +} - fn not_in_list_expr(self, values: I) -> QueryExpr - where - I: IntoIterator, - V: Into, - { - QueryExpr::from_expr(Expr::InList { - expr: Box::new(self.into_query_value().into_expr()), - list: values - .into_iter() - .map(Into::into) - .map(QueryValue::into_expr) - .collect(), - negated: true, - }) +impl<'bind, 'parent, 'arena, T, A> BindOrmSort<'bind, 'parent, 'arena, T, A> + for CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_sort<'scope>( + self, + _scope: &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result { + Ok(self.into_scalar().into()) } +} - fn between_expr, H: Into>(self, low: L, high: H) -> QueryExpr { - QueryExpr::from_expr(Expr::Between { - expr: Box::new(self.into_query_value().into_expr()), - negated: false, - low: Box::new(low.into().into_expr()), - high: Box::new(high.into().into_expr()), - }) +#[doc(hidden)] +pub trait IntoOrmScalarExpression { + fn into_orm_scalar(self) -> ScalarExpression; +} + +impl IntoOrmScalarExpression for E +where + E: Into, +{ + fn into_orm_scalar(self) -> ScalarExpression { + self.into() } +} - fn not_between_expr, H: Into>( +#[doc(hidden)] +pub trait BindOrmScalarList<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn bind_scalar_list( self, - low: L, - high: H, - ) -> QueryExpr { - QueryExpr::from_expr(Expr::Between { - expr: Box::new(self.into_query_value().into_expr()), - negated: true, - low: Box::new(low.into().into_expr()), - high: Box::new(high.into().into_expr()), - }) - } + scope: &mut ExprBindScope<'_, 'bind, 'parent, 'arena, T, A>, + ) -> Result, DatabaseError>; +} - fn cast_value(self, data_type: &str) -> Result { - Ok(self.cast_to_value(parse_data_type_fragment(data_type)?)) - } +macro_rules! impl_bind_orm_scalar_list { + ($(($($name:ident),+)),+ $(,)?) => { + $( + impl<'bind, 'parent, 'arena, Tx, Args, $($name),+> BindOrmScalarList<'bind, 'parent, 'arena, Tx, Args> + for ($($name,)+) + where + Tx: Transaction, + Args: AsRef<[(&'static str, DataValue)]>, + $($name: BindOrmScalar<'bind, 'parent, 'arena, Tx, Args>,)+ + { + #[allow(non_snake_case)] + fn bind_scalar_list( + self, + scope: &mut ExprBindScope<'_, 'bind, 'parent, 'arena, Tx, Args>, + ) -> Result, DatabaseError> { + let ($($name,)+) = self; + Ok(vec![ + $($name.bind_scalar(scope)?,)+ + ]) + } + } + )+ + }; +} - fn cast_to_value(self, data_type: DataType) -> QueryValue { - QueryValue::from_expr(Expr::Cast { - kind: CastKind::Cast, - expr: Box::new(self.into_query_value().into_expr()), - data_type, - array: false, - format: None, - }) - } +impl_bind_orm_scalar_list!( + (A, B), + (A, B, C), + (A, B, C, D), + (A, B, C, D, E), + (A, B, C, D, E, F), + (A, B, C, D, E, F, G), + (A, B, C, D, E, F, G, H), +); - fn alias_value(self, alias: &str) -> ProjectedValue { - ProjectedValue { - item: SelectItem::ExprWithAlias { - expr: self.into_query_value().into_expr(), - alias: ident(alias), - }, - } - } +macro_rules! impl_quantified_subquery_methods { + ($($method:ident, $quantifier:ident, $negated:expr, $op:ident;)+) => { + $( + pub fn $method(self, build: F) -> Result + where + F: for<'scope, 'sub_bind, 'sub_parent> FnOnce( + &'scope mut OrmContext<'scope, 'sub_bind, 'sub_parent, 'arena, T, A>, + ) -> Result, + { + self.quantified_subquery( + MarkApplyQuantifier::$quantifier, + $negated, + expression::BinaryOperator::$op, + build, + ) + } + )+ + }; +} - fn in_subquery_expr(self, subquery: S) -> QueryExpr { - QueryExpr::from_expr(Expr::InSubquery { - expr: Box::new(self.into_query_value().into_expr()), - subquery: Box::new(subquery.into_subquery()), - negated: false, - }) - } +#[allow(clippy::type_complexity)] +struct ExprBindScopeHandle<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + binder: NonNull>, + arena: NonNull>, + _marker: PhantomData<(&'bind (), &'parent (), &'arena (), T, A, Rc<()>)>, +} - fn not_in_subquery_expr(self, subquery: S) -> QueryExpr { - QueryExpr::from_expr(Expr::InSubquery { - expr: Box::new(self.into_query_value().into_expr()), - subquery: Box::new(subquery.into_subquery()), - negated: true, - }) +impl<'bind, 'parent, 'arena, T, A> Clone for ExprBindScopeHandle<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn clone(&self) -> Self { + *self } } -impl Field { - #[doc(hidden)] - pub const fn new(table: &'static str, column: &'static str) -> Self { +impl<'bind, 'parent, 'arena, T, A> Copy for ExprBindScopeHandle<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ +} + +impl<'bind, 'parent, 'arena, T, A> ExprBindScopeHandle<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn new<'ctx>(scope: &ExprBindScope<'ctx, 'bind, 'parent, 'arena, T, A>) -> Self { Self { - table, - column, + binder: NonNull::new((&*scope.binder) as *const _ as *mut _).unwrap(), + arena: NonNull::new((&*scope.arena) as *const _ as *mut _).unwrap(), _marker: PhantomData, } } - fn value(self) -> QueryValue { - qualified_column_value(self.table, self.column) + fn wrap(self, expr: ScalarExpression) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + CtxExpression { expr, scope: self } } - fn orm_field(self) -> &'static OrmField - where - M: Model, - { - M::fields() - .iter() - .find(|field| field.column == self.column) - .expect("ORM field metadata must exist for generated model fields") + #[allow(clippy::mut_from_ref)] + fn binder(&self) -> &mut Binder<'bind, 'parent, T, A> { + // SAFETY: ExprBindScopeHandle is created only from an active ExprBindScope + // during synchronous ORM binding. CtxExpression is !Send and !Sync, and + // all public ORM entry points immediately normalize expressions before + // leaving the bind/filter/project closure, so this pointer is never used + // after its owning binder scope has ended. + unsafe { &mut *self.binder.as_ptr() } } - /// Re-qualifies this field with a different source name such as a table alias. - /// - /// ```rust,ignore - /// let user = database - /// .from::() - /// .alias("u") - /// .eq(User::id().qualify("u"), 1) - /// .get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn qualify(self, relation: &str) -> QueryValue { - qualified_column_value(relation, self.column) - } - - /// Builds `field + value`. - #[allow(clippy::should_implement_trait)] - pub fn add>(self, value: V) -> QueryValue { - ValueExpressionOps::add_expr(self, value) + #[allow(clippy::mut_from_ref)] + fn arena(&self) -> &mut PlanArena<'arena> { + // SAFETY: See binder(); the arena pointer has the same scope-bound + // lifetime and is accessed only through ORM expression binding methods. + unsafe { &mut *self.arena.as_ptr() } } - /// Builds `field - value`. - #[allow(clippy::should_implement_trait)] - pub fn sub>(self, value: V) -> QueryValue { - ValueExpressionOps::sub_expr(self, value) + fn binary( + self, + left: ScalarExpression, + op: expression::BinaryOperator, + right: ScalarExpression, + ) -> Result, DatabaseError> { + self.binder() + .bind_binary_op_expr(left, right, op, self.arena()) + .map(|expr| self.wrap(expr)) } - /// Builds `field * value`. - #[allow(clippy::should_implement_trait)] - pub fn mul>(self, value: V) -> QueryValue { - ValueExpressionOps::mul_expr(self, value) + fn unary( + self, + op: expression::UnaryOperator, + expr: ScalarExpression, + ) -> Result, DatabaseError> { + self.binder() + .bind_unary_op_expr(expr, op, self.arena()) + .map(|expr| self.wrap(expr)) } - /// Builds `field / value`. - #[allow(clippy::should_implement_trait)] - pub fn div>(self, value: V) -> QueryValue { - ValueExpressionOps::div_expr(self, value) + fn function( + self, + name: impl Into, + args: Vec, + ) -> Result, DatabaseError> { + self.binder() + .bind_function_call(name.into(), args, false, self.arena()) + .map(|expr| self.wrap(expr)) } - /// Builds `field % value`. - pub fn modulo>(self, value: V) -> QueryValue { - ValueExpressionOps::modulo_expr(self, value) + fn scalar_subquery( + self, + build: F, + ) -> Result, DatabaseError> + where + F: for<'scope, 'sub_bind, 'sub_parent> FnOnce( + &'scope mut OrmContext<'scope, 'sub_bind, 'sub_parent, 'arena, T, A>, + ) + -> Result, + { + self.binder() + .bind_scalar_subquery_plan(self.arena(), |binder, arena| { + let mut context = OrmContext { binder, arena }; + build(&mut context) + }) + .map(|expr| self.wrap(expr)) } - /// Builds unary `-field`. - #[allow(clippy::should_implement_trait)] - pub fn neg(self) -> QueryValue { - ValueExpressionOps::neg_expr(self) + fn exists_subquery( + self, + negated: bool, + build: F, + ) -> Result, DatabaseError> + where + F: for<'scope, 'sub_bind, 'sub_parent> FnOnce( + &'scope mut OrmContext<'scope, 'sub_bind, 'sub_parent, 'arena, T, A>, + ) + -> Result, + { + self.binder() + .bind_exists_subquery_plan(negated, self.arena(), |binder, arena| { + let mut context = OrmContext { binder, arena }; + build(&mut context) + }) + .map(|expr| self.wrap(expr)) } - /// Builds `field = value`. - pub fn eq>(self, value: V) -> QueryExpr { - ValueExpressionOps::eq_expr(self, value) + fn quantified_subquery( + self, + quantifier: MarkApplyQuantifier, + negated: bool, + left_expr: ScalarExpression, + compare_op: expression::BinaryOperator, + build: F, + ) -> Result, DatabaseError> + where + F: for<'scope, 'sub_bind, 'sub_parent> FnOnce( + &'scope mut OrmContext<'scope, 'sub_bind, 'sub_parent, 'arena, T, A>, + ) + -> Result, + { + self.binder() + .bind_quantified_subquery_plan( + quantifier, + negated, + left_expr, + compare_op, + self.arena(), + |binder, arena| { + let mut context = OrmContext { binder, arena }; + build(&mut context) + }, + ) + .map(|expr| self.wrap(expr)) } +} + +/// ORM expression bound to the current query scope. +/// +/// `CtxExpression` is a scope-bound ORM expression handle, not a reusable core +/// expression value. It exists so ORM code can use natural chained binding such +/// as `e.column(User::age())?.gte(18)?`. Convert it to a core +/// [`ScalarExpression`] only at ORM binder boundaries with [`Self::into_scalar`]. +/// +/// This type intentionally cannot be sent or shared across threads, and its +/// internal scope handle is private. +pub struct CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + expr: ScalarExpression, + scope: ExprBindScopeHandle<'bind, 'parent, 'arena, T, A>, +} - /// Builds `field <> value`. - pub fn ne>(self, value: V) -> QueryExpr { - ValueExpressionOps::ne_expr(self, value) +impl<'bind, 'parent, 'arena, T, A> CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub fn into_scalar(self) -> ScalarExpression { + self.expr } - /// Builds `field > value`. - pub fn gt>(self, value: V) -> QueryExpr { - ValueExpressionOps::gt_expr(self, value) + pub fn into_sort(self) -> SortField { + self.into_scalar().into() } - /// Builds `field >= value`. - pub fn gte>(self, value: V) -> QueryExpr { - ValueExpressionOps::gte_expr(self, value) + pub fn asc(self) -> SortField { + self.into_sort().asc() } - /// Builds `field < value`. - pub fn lt>(self, value: V) -> QueryExpr { - ValueExpressionOps::lt_expr(self, value) + pub fn desc(self) -> SortField { + self.into_sort().desc() } - /// Builds `field <= value`. - pub fn lte>(self, value: V) -> QueryExpr { - ValueExpressionOps::lte_expr(self, value) + pub fn nulls_first(self) -> SortField { + self.into_sort().nulls_first() } - /// Builds `field IS NULL`. - pub fn is_null(self) -> QueryExpr { - ValueExpressionOps::is_null_expr(self) + pub fn nulls_last(self) -> SortField { + self.into_sort().nulls_last() } - /// Builds `field IS NOT NULL`. - pub fn is_not_null(self) -> QueryExpr { - ValueExpressionOps::is_not_null_expr(self) + pub fn eq(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::Eq, + right.into_orm_scalar(), + ) } - /// Builds `field LIKE pattern`. - pub fn like>(self, pattern: V) -> QueryExpr { - ValueExpressionOps::like_expr(self, pattern) + pub fn ne(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::NotEq, + right.into_orm_scalar(), + ) } - /// Builds `field NOT LIKE pattern`. - pub fn not_like>(self, pattern: V) -> QueryExpr { - ValueExpressionOps::not_like_expr(self, pattern) + pub fn gt(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::Gt, + right.into_orm_scalar(), + ) } - /// Builds `field IN (...)`. - pub fn in_list(self, values: I) -> QueryExpr - where - I: IntoIterator, - V: Into, - { - ValueExpressionOps::in_list_expr(self, values) + pub fn gte(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::GtEq, + right.into_orm_scalar(), + ) } - /// Builds `field NOT IN (...)`. - pub fn not_in_list(self, values: I) -> QueryExpr - where - I: IntoIterator, - V: Into, - { - ValueExpressionOps::not_in_list_expr(self, values) + pub fn lt(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::Lt, + right.into_orm_scalar(), + ) } - /// Builds `field BETWEEN low AND high`. - pub fn between, H: Into>(self, low: L, high: H) -> QueryExpr { - ValueExpressionOps::between_expr(self, low, high) + pub fn lte(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::LtEq, + right.into_orm_scalar(), + ) } - /// Builds `field NOT BETWEEN low AND high`. - pub fn not_between, H: Into>( - self, - low: L, - high: H, - ) -> QueryExpr { - ValueExpressionOps::not_between_expr(self, low, high) + pub fn like(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::Like(None), + right.into_orm_scalar(), + ) } - /// Casts this field using a SQL type string such as `"BIGINT"`. - pub fn cast(self, data_type: &str) -> Result { - ValueExpressionOps::cast_value(self, data_type) + pub fn not_like(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::NotLike(None), + right.into_orm_scalar(), + ) } - /// Casts this field using an explicit SQL AST data type. - pub fn cast_to(self, data_type: DataType) -> QueryValue { - ValueExpressionOps::cast_to_value(self, data_type) + pub fn and(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::And, + right.into_orm_scalar(), + ) } - /// Aliases this field in the select list. - pub fn alias(self, alias: &str) -> ProjectedValue { - ValueExpressionOps::alias_value(self, alias) + pub fn or(self, right: R) -> Result { + self.scope.binary( + self.expr, + expression::BinaryOperator::Or, + right.into_orm_scalar(), + ) } - /// Builds `field IN (subquery)`. - pub fn in_subquery(self, subquery: S) -> QueryExpr { - ValueExpressionOps::in_subquery_expr(self, subquery) + #[allow(clippy::should_implement_trait)] + pub fn not(self) -> Result { + self.scope.unary(expression::UnaryOperator::Not, self.expr) } - /// Builds `field NOT IN (subquery)`. - pub fn not_in_subquery(self, subquery: S) -> QueryExpr { - ValueExpressionOps::not_in_subquery_expr(self, subquery) + pub fn is_null(self) -> Self { + let scope = self.scope; + let expr = ScalarExpression::IsNull { + negated: false, + expr: Box::new(self.expr), + }; + scope.wrap(expr) } -} -#[derive(Debug, Clone, PartialEq)] -/// A lightweight ORM expression wrapper for value-producing SQL AST nodes. -/// -/// `QueryValue` is the common currency for computed ORM expressions such as -/// functions, aggregates, `CASE` expressions, casts, and subqueries. -/// -/// It also supports the same comparison and predicate-style composition used by -/// [`Field`], including helpers such as `eq`, `gt`, `like`, `in_list`, and -/// `between`. -/// -/// ```rust,ignore -/// let adults = database -/// .from::() -/// .filter(User::age().gt(18)) -/// .fetch()?; -/// # Ok::<(), kite_sql::errors::DatabaseError>(()) -/// ``` -pub struct QueryValue { - expr: Expr, -} + pub fn is_not_null(self) -> Self { + let scope = self.scope; + let expr = ScalarExpression::IsNull { + negated: true, + expr: Box::new(self.expr), + }; + scope.wrap(expr) + } -#[derive(Debug, Clone, PartialEq)] -/// A projected ORM expression, optionally carrying a select-list alias. -/// -/// This is typically produced by calling `.alias(...)` on a [`Field`] or -/// [`QueryValue`], and then passed into `project_value(...)` or -/// `project_tuple(...)`. -#[doc(hidden)] -pub struct ProjectedValue { - item: SelectItem, -} + pub fn in_list(self, values: I) -> Result + where + I: IntoIterator, + E: IntoOrmScalarExpression, + { + let scope = self.scope; + let expr = ScalarExpression::In { + negated: false, + expr: Box::new(self.expr), + args: values + .into_iter() + .map(IntoOrmScalarExpression::into_orm_scalar) + .collect(), + }; + Ok(scope.wrap(expr)) + } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum CompareOp { - Eq, - Ne, - Gt, - Gte, - Lt, - Lte, -} + pub fn not_in_list(self, values: I) -> Result + where + I: IntoIterator, + E: IntoOrmScalarExpression, + { + let scope = self.scope; + let expr = ScalarExpression::In { + negated: true, + expr: Box::new(self.expr), + args: values + .into_iter() + .map(IntoOrmScalarExpression::into_orm_scalar) + .collect(), + }; + Ok(scope.wrap(expr)) + } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum QuantifiedSubquery { - Any, - Some, - All, -} + pub fn between(self, low: L, high: H) -> Result + where + L: IntoOrmScalarExpression, + H: IntoOrmScalarExpression, + { + let scope = self.scope; + let expr = ScalarExpression::Between { + negated: false, + expr: Box::new(self.expr), + left_expr: Box::new(low.into_orm_scalar()), + right_expr: Box::new(high.into_orm_scalar()), + }; + Ok(scope.wrap(expr)) + } -macro_rules! quantified_value_methods { - ($(($method:ident, $op:ident, $quantifier:ident, $symbol:literal, $keyword:literal)),+ $(,)?) => { - $( - #[doc = concat!("Builds `expr ", $symbol, " ", $keyword, " (subquery)`.")] - pub fn $method(self, subquery: S) -> QueryExpr { - ValueExpressionOps::quantified_subquery_expr( - self, - CompareOp::$op, - QuantifiedSubquery::$quantifier, - subquery, - ) - } - )+ - }; -} + pub fn not_between(self, low: L, high: H) -> Result + where + L: IntoOrmScalarExpression, + H: IntoOrmScalarExpression, + { + let scope = self.scope; + let expr = ScalarExpression::Between { + negated: true, + expr: Box::new(self.expr), + left_expr: Box::new(low.into_orm_scalar()), + right_expr: Box::new(high.into_orm_scalar()), + }; + Ok(scope.wrap(expr)) + } -macro_rules! quantified_methods { - () => { - quantified_value_methods!( - (eq_any, Eq, Any, "=", "ANY"), - (ne_any, Ne, Any, "<>", "ANY"), - (gt_any, Gt, Any, ">", "ANY"), - (gte_any, Gte, Any, ">=", "ANY"), - (lt_any, Lt, Any, "<", "ANY"), - (lte_any, Lte, Any, "<=", "ANY"), - (eq_some, Eq, Some, "=", "SOME"), - (ne_some, Ne, Some, "<>", "SOME"), - (gt_some, Gt, Some, ">", "SOME"), - (gte_some, Gte, Some, ">=", "SOME"), - (lt_some, Lt, Some, "<", "SOME"), - (lte_some, Lte, Some, "<=", "SOME"), - (eq_all, Eq, All, "=", "ALL"), - (ne_all, Ne, All, "<>", "ALL"), - (gt_all, Gt, All, ">", "ALL"), - (gte_all, Gte, All, ">=", "ALL"), - (lt_all, Lt, All, "<", "ALL"), - (lte_all, Lte, All, "<=", "ALL"), - ); - }; -} + pub fn alias(self, alias: impl Into) -> Self { + let scope = self.scope; + let alias = alias.into(); + scope + .binder() + .context + .add_alias(None, alias.clone(), self.expr.clone()); + let expr = ScalarExpression::Alias { + expr: Box::new(self.expr), + alias: AliasType::Name(alias), + }; + scope.wrap(expr) + } -impl Field { - quantified_methods!(); + pub fn cast(self, ty: LogicalType) -> Result { + let scope = self.scope; + ScalarExpression::type_cast(self.expr, Cow::Owned(ty), scope.arena()) + .map(|expr| scope.wrap(expr)) + } + + pub fn function( + self, + name: impl Into, + args: impl IntoIterator, + ) -> Result + where + E: IntoOrmScalarExpression, + { + let mut args = args + .into_iter() + .map(IntoOrmScalarExpression::into_orm_scalar) + .collect::>(); + args.insert(0, self.expr); + self.scope.function(name, args) + } + + fn quantified_subquery( + self, + quantifier: MarkApplyQuantifier, + negated: bool, + compare_op: expression::BinaryOperator, + build: F, + ) -> Result + where + F: for<'scope, 'sub_bind, 'sub_parent> FnOnce( + &'scope mut OrmContext<'scope, 'sub_bind, 'sub_parent, 'arena, T, A>, + ) + -> Result, + { + self.scope + .quantified_subquery(quantifier, negated, self.expr, compare_op, build) + } + + impl_quantified_subquery_methods! { + eq_any, Any, false, Eq; + eq_all, All, false, Eq; + gt_any, Any, false, Gt; + gt_all, All, false, Gt; + gte_any, Any, false, GtEq; + gte_all, All, false, GtEq; + lt_any, Any, false, Lt; + lt_all, All, false, Lt; + lte_any, Any, false, LtEq; + lte_all, All, false, LtEq; + in_subquery, Any, false, Eq; + not_in_subquery, Any, true, Eq; + } } -#[derive(Debug, Clone, PartialEq)] -/// A lightweight ORM expression wrapper for predicate-oriented SQL AST nodes. -/// -/// `QueryExpr` is used for `WHERE` and `HAVING` clauses, as well as boolean -/// composition such as `and`, `or`, and `not`. -pub struct QueryExpr { - expr: Expr, +impl<'bind, 'parent, 'arena, T, A> fmt::Debug for CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.expr.fmt(f) + } } -/// Builds a scalar function call expression. -/// -/// This is commonly used for UDFs registered on the database. -/// -/// ```rust,ignore -/// let expr = kite_sql::orm::func("add_one", [User::id()]); -/// let row = database.from::().eq(expr, 2).get()?; -/// # Ok::<(), kite_sql::errors::DatabaseError>(()) -/// ``` -pub fn func(name: N, args: I) -> QueryValue +impl<'bind, 'parent, 'arena, T, A> Clone for CtxExpression<'bind, 'parent, 'arena, T, A> where - N: Into, - I: IntoIterator, - V: Into, + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, { - QueryValue::function(name, args) + fn clone(&self) -> Self { + Self { + expr: self.expr.clone(), + scope: self.scope, + } + } } -/// Builds `count(expr)`. -/// -/// ```rust,ignore -/// let grouped = database -/// .from::() -/// .project_tuple((EventLog::category(), kite_sql::orm::count(EventLog::id()))) -/// .group_by(EventLog::category()) -/// .fetch::<(String, i32)>()?; -/// # Ok::<(), kite_sql::errors::DatabaseError>(()) -/// ``` -pub fn count>(value: V) -> QueryValue { - QueryValue::aggregate("count", [value.into()]) +impl<'bind, 'parent, 'arena, T, A> PartialEq for CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn eq(&self, other: &Self) -> bool { + self.expr == other.expr + } } -/// Builds `count(*)`. -/// -/// ```rust,ignore -/// let total = database -/// .from::() -/// .project_value(kite_sql::orm::count_all().alias("total_users")) -/// .get::()?; -/// # Ok::<(), kite_sql::errors::DatabaseError>(()) -/// ``` -pub fn count_all() -> QueryValue { - QueryValue::aggregate_all("count") +impl<'bind, 'parent, 'arena, T, A> Eq for CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ } -/// Builds `sum(expr)`. -/// -/// ```rust,ignore -/// let totals = database -/// .from::() -/// .project_value(kite_sql::orm::sum(Order::amount())) -/// .get::()?; -/// # Ok::<(), kite_sql::errors::DatabaseError>(()) -/// ``` -pub fn sum>(value: V) -> QueryValue { - QueryValue::aggregate("sum", [value.into()]) +impl<'bind, 'parent, 'arena, T, A> Hash for CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn hash(&self, state: &mut H) { + self.expr.hash(state); + } } -/// Builds `avg(expr)`. -pub fn avg>(value: V) -> QueryValue { - QueryValue::aggregate("avg", [value.into()]) +impl<'bind, 'parent, 'arena, T, A> IntoOrmScalarExpression + for CtxExpression<'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn into_orm_scalar(self) -> ScalarExpression { + self.into_scalar() + } } -/// Builds `min(expr)`. -pub fn min>(value: V) -> QueryValue { - QueryValue::aggregate("min", [value.into()]) +fn bind_orm_context(executor: E, build: F) -> Result +where + E: BindSource, + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + E::Transaction, + &'static [(&'static str, DataValue)], + >, + ) -> Result, +{ + static EMPTY_BIND_PARAMS: &[(&str, DataValue)] = &[]; + executor.execute(EMPTY_BIND_PARAMS, |binder, arena| { + let mut context = OrmContext { binder, arena }; + build(&mut context) + }) } -/// Builds `max(expr)`. -pub fn max>(value: V) -> QueryValue { - QueryValue::aggregate("max", [value.into()]) +fn explain_orm_context(executor: E, build: F) -> Result +where + E: BindSource, + F: for<'ctx, 'bind, 'parent, 'arena> FnOnce( + &'ctx mut OrmContext< + 'ctx, + 'bind, + 'parent, + 'arena, + E::Transaction, + &'static [(&'static str, DataValue)], + >, + ) -> Result, +{ + static EMPTY_BIND_PARAMS: &[(&str, DataValue)] = &[]; + executor.explain(EMPTY_BIND_PARAMS, |binder, arena| { + let mut context = OrmContext { binder, arena }; + build(&mut context) + }) } -/// Builds a searched `CASE WHEN ... THEN ... ELSE ... END` expression. +/// Binder-backed ORM query context. /// -/// ```rust,ignore -/// let bucket = kite_sql::orm::case_when( -/// [(User::age().is_null(), "unknown"), (User::age().lt(20), "minor")], -/// "adult", -/// ); -/// # let _ = bucket; -/// ``` -pub fn case_when(conditions: I, else_result: E) -> QueryValue +/// This context is created by [`Database::bind`] or [`DBTransaction::bind`]. Query construction inside +/// the closure binds directly into [`ScalarExpression`] and [`LogicalPlan`] +/// values; it does not build an ORM expression tree first. +pub struct OrmContext<'ctx, 'bind, 'parent, 'arena, T, A> where - I: IntoIterator, - C: Into, - R: Into, - E: Into, + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, { - QueryValue::searched_case(conditions, else_result) + binder: &'ctx mut Binder<'bind, 'parent, T, A>, + arena: &'ctx mut PlanArena<'arena>, } -/// Builds a simple `CASE value WHEN ... THEN ... ELSE ... END` expression. -/// -/// ```rust,ignore -/// let label = kite_sql::orm::case_value( -/// User::age(), -/// [(18, "adult"), (30, "senior")], -/// "other", -/// ); -/// let rows = database -/// .from::() -/// .project_tuple((User::id(), label.alias("age_label"))) -/// .fetch::<(i32, String)>()?; -/// # Ok::<(), kite_sql::errors::DatabaseError>(()) -/// ``` -pub fn case_value(operand: O, conditions: I, else_result: E) -> QueryValue +/// Narrow expression binding scope borrowed from an [`OrmContext`]. +pub struct ExprBindScope<'ctx, 'bind, 'parent, 'arena, T, A> where - O: Into, - I: IntoIterator, - W: Into, - R: Into, - E: Into, + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, { - QueryValue::simple_case(operand, conditions, else_result) + binder: &'ctx mut Binder<'bind, 'parent, T, A>, + arena: &'ctx mut PlanArena<'arena>, } -impl QueryExpr { - fn from_expr(expr: Expr) -> QueryExpr { - Self { expr } - } +pub struct UpdateBindScope<'ctx, 'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + binder: &'ctx mut Binder<'bind, 'parent, T, A>, + arena: &'ctx mut PlanArena<'arena>, + source_name: String, + value_exprs: Vec<(ColumnRef, ScalarExpression)>, +} - fn into_expr(self) -> Expr { - self.expr +impl<'ctx, 'bind, 'parent, 'arena, T, A> OrmContext<'ctx, 'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub fn from<'scope, M: Model>( + &'scope mut self, + ) -> Result, DatabaseError> { + self.from_source(QuerySource::model::(), false) } - /// Combines two predicates with `AND`. - /// - /// ```rust,ignore - /// let expr = User::age().gte(18).and(User::name().like("A%")); - /// let users = database.from::().filter(expr).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn and(self, rhs: QueryExpr) -> QueryExpr { - QueryExpr::from_expr(Expr::BinaryOp { - left: Box::new(nested_expr(self.into_expr())), - op: SqlBinaryOperator::And, - right: Box::new(nested_expr(rhs.into_expr())), - }) + pub fn from_as<'scope, M: Model>( + &'scope mut self, + alias: impl Into, + ) -> Result, DatabaseError> { + self.from_source(QuerySource::model::().with_alias(alias), false) } - /// Combines two predicates with `OR`. - /// - /// ```rust,ignore - /// let expr = User::name().like("A%").or(User::name().like("B%")); - /// let users = database.from::().filter(expr).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn or(self, rhs: QueryExpr) -> QueryExpr { - QueryExpr::from_expr(Expr::BinaryOp { - left: Box::new(nested_expr(self.into_expr())), - op: SqlBinaryOperator::Or, - right: Box::new(nested_expr(rhs.into_expr())), - }) + pub fn mutate<'scope, M: Model>( + &'scope mut self, + ) -> Result, DatabaseError> { + self.from_source(QuerySource::model::(), true) } - /// Negates a predicate with `NOT`. - /// - /// ```rust,ignore - /// let expr = User::name().like("A%").not(); - /// let users = database.from::().filter(expr).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - #[allow(clippy::should_implement_trait)] - pub fn not(self) -> QueryExpr { - QueryExpr::from_expr(Expr::UnaryOp { - op: sqlparser::ast::UnaryOperator::Not, - expr: Box::new(nested_expr(self.into_expr())), - }) + pub fn mutate_as<'scope, M: Model>( + &'scope mut self, + alias: impl Into, + ) -> Result, DatabaseError> { + self.from_source(QuerySource::model::().with_alias(alias), true) } - /// Builds an `EXISTS (subquery)` predicate. - /// - /// ```rust,ignore - /// let expr = kite_sql::orm::QueryExpr::exists( - /// database.from::().project_value(User::id()).eq(User::id(), 1), - /// ); - /// let found = database.from::().filter(expr).exists()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn exists(subquery: S) -> QueryExpr { - QueryExpr::from_expr(Expr::Exists { - subquery: Box::new(subquery.into_subquery()), - negated: false, - }) + #[allow(clippy::wrong_self_convention)] + fn from_source<'scope, M: Model>( + &'scope mut self, + source: QuerySource, + mutation_source: bool, + ) -> Result, DatabaseError> { + if mutation_source { + self.binder.with_pk(source.table_name.as_str().into()); + } + let plan = bind_orm_source(self.binder, source.clone(), None, self.arena); + if mutation_source { + self.binder.clear_with_pk(); + } + let plan = plan?; + self.binder + .build_plan(self.arena) + .from_plan(plan) + .map(|from| from.typed()) } - /// Builds a `NOT EXISTS (subquery)` predicate. - /// - /// ```rust,ignore - /// let expr = kite_sql::orm::QueryExpr::not_exists( - /// database.from::().project_value(Order::id()).eq(Order::user_id(), User::id()), - /// ); - /// let users = database.from::().filter(expr).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn not_exists(subquery: S) -> QueryExpr { - QueryExpr::from_expr(Expr::Exists { - subquery: Box::new(subquery.into_subquery()), - negated: true, - }) + fn set_operation( + &mut self, + op: SetOperatorKind, + all: bool, + left: L, + right: R, + ) -> Result + where + L: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + R: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + { + let left_plan = self.child_plan(left)?; + let right_plan = self.child_plan(right)?; + self.binder + .bind_set_operation_plans(op, all, left_plan, right_plan, self.arena) } -} - -#[derive(Debug, Clone, PartialEq)] -struct SortExpr { - value: QueryValue, - desc: bool, - nulls_first: Option, -} -impl SortExpr { - fn new(value: QueryValue, desc: bool) -> Self { - Self { - value, - desc, - nulls_first: None, + fn child_plan(&mut self, build: F) -> Result + where + F: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + { + let mut child_binder = Binder::new( + self.binder.context.fork(), + self.binder.args, + self.binder.parent, + ); + let plan = { + let mut context = OrmContext { + binder: &mut child_binder, + arena: self.arena, + }; + build(&mut context)? + }; + if child_binder.context.has_outer_refs() { + self.binder.context.mark_outer_ref(); } + Ok(plan) } - fn with_nulls(mut self, nulls_first: bool) -> Self { - self.nulls_first = Some(nulls_first); - self + pub fn union( + &mut self, + all: bool, + left: L, + right: R, + ) -> Result + where + L: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + R: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + { + self.set_operation(SetOperatorKind::Union, all, left, right) } - fn into_ast(self) -> OrderByExpr { - OrderByExpr { - expr: self.value.into_expr(), - options: OrderByOptions { - asc: Some(!self.desc), - nulls_first: self.nulls_first, - }, - with_fill: None, - } + pub fn except( + &mut self, + all: bool, + left: L, + right: R, + ) -> Result + where + L: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + R: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + { + self.set_operation(SetOperatorKind::Except, all, left, right) } -} -impl QueryValue { - fn from_expr(expr: Expr) -> Self { - Self { expr } + pub fn intersect( + &mut self, + all: bool, + left: L, + right: R, + ) -> Result + where + L: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + R: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + { + self.set_operation(SetOperatorKind::Intersect, all, left, right) } - fn into_expr(self) -> Expr { - self.expr + pub fn insert_select( + &mut self, + columns: C, + build: F, + ) -> Result + where + M: Model, + C: IntoIterator, + C::Item: Into, + F: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, + { + self.insert_select_inner::(columns, false, build) } - /// Builds a scalar function call. - /// - /// ```rust,ignore - /// let expr = kite_sql::orm::QueryValue::function("add_one", [User::id()]); - /// let row = database.from::().eq(expr, 2).get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn function(name: N, args: I) -> Self + pub fn overwrite_select( + &mut self, + columns: C, + build: F, + ) -> Result where - N: Into, - I: IntoIterator, - V: Into, - { - Self::function_with_args( - name, - args.into_iter() - .map(Into::into) - .map(QueryValue::into_expr) - .map(FunctionArgExpr::Expr), + M: Model, + C: IntoIterator, + C::Item: Into, + F: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, ) + -> Result, + { + self.insert_select_inner::(columns, true, build) } - /// Builds an aggregate function call such as `sum(expr)` or `count(expr)`. - /// - /// ```rust,ignore - /// let total = kite_sql::orm::QueryValue::aggregate("sum", [Order::amount()]); - /// let row = database.from::().project_value(total).get::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn aggregate(name: N, args: I) -> Self + fn insert_select_inner( + &mut self, + columns: C, + overwrite: bool, + build: F, + ) -> Result where - N: Into, - I: IntoIterator, - V: Into, + M: Model, + C: IntoIterator, + C::Item: Into, + F: for<'scope, 'child_bind, 'child_parent> FnOnce( + &'scope mut OrmContext<'scope, 'child_bind, 'child_parent, 'arena, T, A>, + ) + -> Result, { - Self::function(name, args) + let input_plan = self.child_plan(build)?; + bind_orm_insert_plan( + self.binder, + M::table_name(), + columns.into_iter().map(Into::into).collect(), + input_plan, + overwrite, + self.arena, + ) } - /// Builds an aggregate function call that uses `*`, such as `count(*)`. - /// - /// ```rust,ignore - /// let total = kite_sql::orm::QueryValue::aggregate_all("count").alias("total"); - /// let row = database.from::().project_value(total).get::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn aggregate_all(name: impl Into) -> Self { - Self::function_with_args(name, [FunctionArgExpr::Wildcard]) + pub fn truncate(&mut self) -> Result { + self.binder.bind_truncate(M::table_name().into()) } +} - /// Assigns a select-list alias to this value expression. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .project_value(kite_sql::orm::sum(Order::amount()).alias("total_amount")) - /// .raw()?; - /// # rows.done()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn alias(self, alias: &str) -> ProjectedValue { - ProjectedValue { - item: SelectItem::ExprWithAlias { - expr: self.into_expr(), - alias: ident(alias), - }, - } +impl<'ctx, 'bind, 'parent, 'arena, T, A> ExprBindScope<'ctx, 'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + fn handle(&self) -> ExprBindScopeHandle<'bind, 'parent, 'arena, T, A> { + ExprBindScopeHandle::new(self) + } + + fn wrap(&self, expr: ScalarExpression) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + self.handle().wrap(expr) + } + + pub fn column( + &self, + field: Field, + ) -> Result, DatabaseError> { + let scope = self.handle(); + let expr = scope.binder().bind_column_ref_by_name( + Some(field.table), + field.column, + None, + scope.arena(), + )?; + Ok(scope.wrap(expr)) + } + + pub fn qualified_column( + &self, + relation: &str, + field: Field, + ) -> Result, DatabaseError> { + let scope = self.handle(); + let expr = scope.binder().bind_column_ref_by_name( + Some(relation), + field.column, + None, + scope.arena(), + )?; + Ok(scope.wrap(expr)) } - /// Builds a searched `CASE WHEN ... THEN ... ELSE ... END` expression. - /// - /// ```rust,ignore - /// let label = kite_sql::orm::QueryValue::searched_case( - /// [(User::age().is_null(), "unknown"), (User::age().lt(18), "minor")], - /// "adult", - /// ); - /// let rows = database - /// .from::() - /// .project_tuple((User::name(), label.alias("age_group"))) - /// .fetch::<(String, String)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn searched_case(conditions: I, else_result: E) -> Self + #[doc(hidden)] + pub fn column_ref( + &self, + relation: &str, + column: &str, + ) -> Result, DatabaseError> { + let scope = self.handle(); + let expr = + scope + .binder() + .bind_column_ref_by_name(Some(relation), column, None, scope.arena())?; + Ok(scope.wrap(expr)) + } + + pub fn value(&self, value: V) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + self.wrap(ScalarExpression::Constant(value.to_data_value())) + } + + pub fn data_value(&self, value: DataValue) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + self.wrap(ScalarExpression::Constant(value)) + } + + pub fn alias( + &self, + expr: impl IntoOrmScalarExpression, + alias: impl Into, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + self.wrap(expr.into_orm_scalar()).alias(alias) + } + + pub fn cast( + &self, + expr: impl IntoOrmScalarExpression, + ty: LogicalType, + ) -> Result, DatabaseError> { + let scope = self.handle(); + let expr = + ScalarExpression::type_cast(expr.into_orm_scalar(), Cow::Owned(ty), scope.arena())?; + Ok(scope.wrap(expr)) + } + + pub fn unary( + &self, + op: expression::UnaryOperator, + expr: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.handle().unary(op, expr.into_orm_scalar()) + } + + pub fn binary( + &self, + left: impl IntoOrmScalarExpression, + op: expression::BinaryOperator, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.handle() + .binary(left.into_orm_scalar(), op, right.into_orm_scalar()) + } + + pub fn eq( + &self, + left: impl IntoOrmScalarExpression, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.binary(left, expression::BinaryOperator::Eq, right) + } + + pub fn ne( + &self, + left: impl IntoOrmScalarExpression, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.binary(left, expression::BinaryOperator::NotEq, right) + } + + pub fn gt( + &self, + left: impl IntoOrmScalarExpression, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.binary(left, expression::BinaryOperator::Gt, right) + } + + pub fn gte( + &self, + left: impl IntoOrmScalarExpression, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.binary(left, expression::BinaryOperator::GtEq, right) + } + + pub fn lt( + &self, + left: impl IntoOrmScalarExpression, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.binary(left, expression::BinaryOperator::Lt, right) + } + + pub fn lte( + &self, + left: impl IntoOrmScalarExpression, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.binary(left, expression::BinaryOperator::LtEq, right) + } + + pub fn and( + &self, + left: impl IntoOrmScalarExpression, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.binary(left, expression::BinaryOperator::And, right) + } + + pub fn or( + &self, + left: impl IntoOrmScalarExpression, + right: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.binary(left, expression::BinaryOperator::Or, right) + } + + pub fn is_null( + &self, + expr: impl IntoOrmScalarExpression, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + self.wrap(expr.into_orm_scalar()).is_null() + } + + pub fn is_not_null( + &self, + expr: impl IntoOrmScalarExpression, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + self.wrap(expr.into_orm_scalar()).is_not_null() + } + + pub fn in_list( + &self, + expr: impl IntoOrmScalarExpression, + args: I, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> where - I: IntoIterator, - C: Into, - R: Into, - E: Into, + I: IntoIterator, + E: IntoOrmScalarExpression, { - Self::from_expr(Expr::Case { - case_token: AttachedToken::empty(), - end_token: AttachedToken::empty(), - operand: None, - conditions: conditions + let expr = ScalarExpression::In { + negated: false, + expr: Box::new(expr.into_orm_scalar()), + args: args .into_iter() - .map(|(condition, result)| CaseWhen { - condition: condition.into().into_expr(), - result: result.into().into_expr(), - }) + .map(IntoOrmScalarExpression::into_orm_scalar) .collect(), - else_result: Some(Box::new(else_result.into().into_expr())), - }) + }; + self.wrap(expr) } - /// Builds a simple `CASE value WHEN ... THEN ... ELSE ... END` expression. - /// - /// ```rust,ignore - /// let label = kite_sql::orm::QueryValue::simple_case( - /// User::status(), - /// [("active", "enabled"), ("disabled", "blocked")], - /// "other", - /// ); - /// let rows = database - /// .from::() - /// .project_tuple((User::id(), label.alias("status_label"))) - /// .fetch::<(i32, String)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn simple_case(operand: O, conditions: I, else_result: E) -> Self + pub fn not_in_list( + &self, + expr: impl IntoOrmScalarExpression, + args: I, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> where - O: Into, - I: IntoIterator, - W: Into, - R: Into, - E: Into, + I: IntoIterator, + E: IntoOrmScalarExpression, { - Self::from_expr(Expr::Case { - case_token: AttachedToken::empty(), - end_token: AttachedToken::empty(), - operand: Some(Box::new(operand.into().into_expr())), - conditions: conditions + let expr = ScalarExpression::In { + negated: true, + expr: Box::new(expr.into_orm_scalar()), + args: args .into_iter() - .map(|(condition, result)| CaseWhen { - condition: condition.into().into_expr(), - result: result.into().into_expr(), - }) + .map(IntoOrmScalarExpression::into_orm_scalar) .collect(), - else_result: Some(Box::new(else_result.into().into_expr())), - }) - } - - /// Builds `expr + value`. - #[allow(clippy::should_implement_trait)] - pub fn add>(self, value: V) -> QueryValue { - ValueExpressionOps::add_expr(self, value) + }; + self.wrap(expr) } - /// Builds `expr - value`. - #[allow(clippy::should_implement_trait)] - pub fn sub>(self, value: V) -> QueryValue { - ValueExpressionOps::sub_expr(self, value) + pub fn between( + &self, + expr: impl IntoOrmScalarExpression, + low: impl IntoOrmScalarExpression, + high: impl IntoOrmScalarExpression, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + let expr = ScalarExpression::Between { + negated: false, + expr: Box::new(expr.into_orm_scalar()), + left_expr: Box::new(low.into_orm_scalar()), + right_expr: Box::new(high.into_orm_scalar()), + }; + self.wrap(expr) } - /// Builds `expr * value`. - #[allow(clippy::should_implement_trait)] - pub fn mul>(self, value: V) -> QueryValue { - ValueExpressionOps::mul_expr(self, value) + pub fn not_between( + &self, + expr: impl IntoOrmScalarExpression, + low: impl IntoOrmScalarExpression, + high: impl IntoOrmScalarExpression, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> { + let expr = ScalarExpression::Between { + negated: true, + expr: Box::new(expr.into_orm_scalar()), + left_expr: Box::new(low.into_orm_scalar()), + right_expr: Box::new(high.into_orm_scalar()), + }; + self.wrap(expr) } - /// Builds `expr / value`. - #[allow(clippy::should_implement_trait)] - pub fn div>(self, value: V) -> QueryValue { - ValueExpressionOps::div_expr(self, value) + pub fn not( + &self, + expr: impl IntoOrmScalarExpression, + ) -> Result, DatabaseError> { + self.unary(expression::UnaryOperator::Not, expr) } - /// Builds `expr % value`. - pub fn modulo>(self, value: V) -> QueryValue { - ValueExpressionOps::modulo_expr(self, value) + pub fn function( + &self, + name: impl Into, + args: impl IntoIterator, + ) -> Result, DatabaseError> + where + E: IntoOrmScalarExpression, + { + let args = args + .into_iter() + .map(IntoOrmScalarExpression::into_orm_scalar) + .collect(); + self.handle().function(name, args) } - /// Builds unary `-expr`. - #[allow(clippy::should_implement_trait)] - pub fn neg(self) -> QueryValue { - ValueExpressionOps::neg_expr(self) + pub fn aggregate( + &self, + name: impl Into, + args: impl IntoIterator, + ) -> Result, DatabaseError> + where + E: IntoOrmScalarExpression, + { + self.function(name, args) } - /// Builds `expr = value`. - pub fn eq>(self, value: V) -> QueryExpr { - ValueExpressionOps::eq_expr(self, value) + pub fn count_all(&self) -> Result, DatabaseError> { + self.function( + "count", + vec![Binder::<'bind, 'parent, T, A>::wildcard_expr()], + ) } - /// Builds `expr <> value`. - pub fn ne>(self, value: V) -> QueryExpr { - ValueExpressionOps::ne_expr(self, value) - } - - /// Builds `expr > value`. - pub fn gt>(self, value: V) -> QueryExpr { - ValueExpressionOps::gt_expr(self, value) - } - - /// Builds `expr >= value`. - pub fn gte>(self, value: V) -> QueryExpr { - ValueExpressionOps::gte_expr(self, value) - } - - /// Builds `expr < value`. - pub fn lt>(self, value: V) -> QueryExpr { - ValueExpressionOps::lt_expr(self, value) - } - - /// Builds `expr <= value`. - pub fn lte>(self, value: V) -> QueryExpr { - ValueExpressionOps::lte_expr(self, value) - } - - quantified_methods!(); - - /// Builds `expr IS NULL`. - pub fn is_null(self) -> QueryExpr { - ValueExpressionOps::is_null_expr(self) - } - - /// Builds `expr IS NOT NULL`. - pub fn is_not_null(self) -> QueryExpr { - ValueExpressionOps::is_not_null_expr(self) - } - - /// Builds `expr LIKE pattern`. - pub fn like>(self, pattern: V) -> QueryExpr { - ValueExpressionOps::like_expr(self, pattern) - } - - /// Builds `expr NOT LIKE pattern`. - pub fn not_like>(self, pattern: V) -> QueryExpr { - ValueExpressionOps::not_like_expr(self, pattern) - } - - /// Builds `expr IN (...)`. - pub fn in_list(self, values: I) -> QueryExpr - where - I: IntoIterator, - V: Into, - { - ValueExpressionOps::in_list_expr(self, values) - } - - /// Builds `expr NOT IN (...)`. - pub fn not_in_list(self, values: I) -> QueryExpr - where - I: IntoIterator, - V: Into, - { - ValueExpressionOps::not_in_list_expr(self, values) - } - - /// Builds `expr BETWEEN low AND high`. - pub fn between, H: Into>(self, low: L, high: H) -> QueryExpr { - ValueExpressionOps::between_expr(self, low, high) - } - - /// Builds `expr NOT BETWEEN low AND high`. - pub fn not_between, H: Into>( - self, - low: L, - high: H, - ) -> QueryExpr { - ValueExpressionOps::not_between_expr(self, low, high) - } - - /// Casts this expression using a SQL type string such as `"BIGINT"`. - /// - /// ```rust,ignore - /// let expr = User::id().cast("BIGINT")?; - /// let row = database.from::().eq(expr, 1_i64).get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn cast(self, data_type: &str) -> Result { - ValueExpressionOps::cast_value(self, data_type) - } - - /// Casts this expression using an explicit SQL AST data type. - /// - /// ```rust,ignore - /// use sqlparser::ast::DataType; - /// - /// let expr = User::id().cast_to(DataType::BigInt(None)); - /// let row = database.from::().eq(expr, 1_i64).get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn cast_to(self, data_type: DataType) -> QueryValue { - ValueExpressionOps::cast_to_value(self, data_type) - } - - /// Wraps a query builder as a scalar subquery expression. - /// - /// ```rust,ignore - /// let max_amount = kite_sql::orm::QueryValue::subquery( - /// database.from::().project_value(kite_sql::orm::max(Order::amount())), - /// ); - /// let row = database - /// .from::() - /// .eq(Order::amount(), max_amount) - /// .get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn subquery(query: S) -> QueryValue { - QueryValue::from_expr(Expr::Subquery(Box::new(query.into_subquery()))) - } - - /// Builds `expr IN (subquery)`. - /// - /// ```rust,ignore - /// let expr = User::id().in_subquery( - /// database.from::().project_value(Order::user_id()), - /// ); - /// let users = database.from::().filter(expr).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn in_subquery(self, subquery: S) -> QueryExpr { - ValueExpressionOps::in_subquery_expr(self, subquery) - } - - /// Builds `expr NOT IN (subquery)`. - /// - /// ```rust,ignore - /// let expr = User::id().not_in_subquery( - /// database.from::().project_value(Order::user_id()), - /// ); - /// let users = database.from::().filter(expr).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn not_in_subquery(self, subquery: S) -> QueryExpr { - ValueExpressionOps::not_in_subquery_expr(self, subquery) - } - - fn asc(self) -> SortExpr { - SortExpr::new(self, false) - } - - fn desc(self) -> SortExpr { - SortExpr::new(self, true) - } - - fn function_with_args(name: N, args: I) -> Self - where - N: Into, - I: IntoIterator, - { - let name = name.into(); - Self::from_expr(Expr::Function(Function { - name: object_name(&name), - uses_odbc_syntax: false, - parameters: FunctionArguments::None, - args: FunctionArguments::List(FunctionArgumentList { - duplicate_treatment: None, - args: args.into_iter().map(FunctionArg::Unnamed).collect(), - clauses: vec![], - }), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - })) - } -} - -impl From> for QueryValue { - fn from(value: Field) -> Self { - value.value() - } -} - -impl ValueExpressionOps for Field { - fn into_query_value(self) -> QueryValue { - self.value() - } -} - -impl ValueExpressionOps for QueryValue { - fn into_query_value(self) -> QueryValue { - self - } -} - -impl ProjectedValue { - fn into_select_item(self) -> SelectItem { - self.item - } -} - -#[doc(hidden)] -pub fn projection_value( - column: &'static str, - relation: &str, - alias: &'static str, -) -> ProjectedValue { - qualified_column_value(relation, column).alias(alias) -} - -#[doc(hidden)] -pub fn projection_column(column: &'static str, relation: &str) -> ProjectedValue { - ProjectedValue::from(qualified_column_value(relation, column)) -} - -impl> From for ProjectedValue { - fn from(value: V) -> Self { - Self { - item: SelectItem::UnnamedExpr(value.into().into_expr()), - } - } -} - -macro_rules! impl_into_projected_tuple { - ($(($($name:ident),+)),+ $(,)?) => { - $( - impl<$($name),+> IntoProjectedTuple for ($($name,)+) - where - $($name: Into,)+ - { - #[allow(non_snake_case)] - fn into_projected_values(self) -> Vec { - let ($($name,)+) = self; - vec![$($name.into(),)+] - } - } - )+ - }; -} - -impl_into_projected_tuple!( - (A, B), - (A, B, C), - (A, B, C, D), - (A, B, C, D, E), - (A, B, C, D, E, F), - (A, B, C, D, E, F, G), - (A, B, C, D, E, F, G, H), -); - -impl From for QueryValue { - fn from(value: T) -> Self { - QueryValue::from_expr(data_value_to_ast_expr(&value.to_data_value())) - } -} - -impl CompareOp { - fn as_ast(&self) -> SqlBinaryOperator { - match self { - CompareOp::Eq => SqlBinaryOperator::Eq, - CompareOp::Ne => SqlBinaryOperator::NotEq, - CompareOp::Gt => SqlBinaryOperator::Gt, - CompareOp::Gte => SqlBinaryOperator::GtEq, - CompareOp::Lt => SqlBinaryOperator::Lt, - CompareOp::Lte => SqlBinaryOperator::LtEq, - } - } -} - -impl QuantifiedSubquery { - fn into_ast(self, left: Expr, compare_op: SqlBinaryOperator, right: Expr) -> Expr { - match self { - QuantifiedSubquery::Any | QuantifiedSubquery::Some => Expr::AnyOp { - left: Box::new(left), - compare_op, - right: Box::new(right), - is_some: matches!(self, QuantifiedSubquery::Some), - }, - QuantifiedSubquery::All => Expr::AllOp { - left: Box::new(left), - compare_op, - right: Box::new(right), - }, - } - } -} - -#[doc(hidden)] -pub trait StatementSource { - type Iter: ResultIter; - - /// Executes a prepared ORM statement with named parameters. - fn execute_statement>( - self, - statement: &Statement, - params: A, - ) -> Result; -} - -impl<'a, S: Storage> StatementSource for &'a Database { - type Iter = DatabaseIter<'a, S>; - - fn execute_statement>( - self, - statement: &Statement, - params: A, - ) -> Result { - self.execute(statement, params) - } -} - -impl<'a, 'tx, S: Storage> StatementSource for &'a mut DBTransaction<'tx, S> { - type Iter = TransactionIter<'a, S::TransactionType<'tx>>; - - fn execute_statement>( - self, - statement: &Statement, - params: A, - ) -> Result { - self.execute(statement, params) - } -} - -mod private { - pub trait Sealed {} -} - -#[doc(hidden)] -pub trait SubquerySource: private::Sealed { - fn into_subquery(self) -> Query; -} - -struct QueryBuilder { - state: BuilderState, - projection: P, -} - -/// Lightweight single-table query builder for ORM models. -/// -/// This is the main entry point returned by `Database::from::()` and -/// `DBTransaction::from::()`. -pub struct FromBuilder { - inner: QueryBuilder, -} - -#[doc(hidden)] -pub struct SetQueryBuilder { - source: Q, - query: Query, - _marker: PhantomData<(M, P)>, -} - -/// ORM update builder produced from [`FromBuilder::update`]. -/// -/// This builder currently supports single-table updates with optional `WHERE` -/// filters inherited from [`FromBuilder`]. -pub struct UpdateBuilder { - state: BuilderState, - assignments: Vec, -} - -#[doc(hidden)] -pub struct JoinOnBuilder { - inner: QueryBuilder, - join_source: QuerySource, - join_kind: JoinKind, -} - -#[doc(hidden)] -pub struct ModelProjection; - -#[doc(hidden)] -pub struct ValueProjection { - value: ProjectedValue, -} - -#[doc(hidden)] -pub struct TupleProjection { - values: Vec, -} - -#[doc(hidden)] -pub struct StructProjection { - _marker: PhantomData, -} - -struct BuilderState { - source: Q, - query_source: QuerySource, - joins: Vec, - distinct: bool, - filter: Option, - group_bys: Vec, - having: Option, - order_bys: Vec, - limit: Option, - offset: Option, - _marker: PhantomData, -} - -struct UpdateAssignment { - column: &'static str, - placeholder: &'static str, - value: UpdateAssignmentValue, -} - -#[allow(clippy::large_enum_variant)] -enum UpdateAssignmentValue { - Param(DataValue), - Expr(QueryValue), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum MutationKind { - Update, - Delete, -} - -impl BuilderState { - fn new(source: Q, query_source: QuerySource) -> Self { - Self { - source, - query_source, - joins: Vec::new(), - distinct: false, - filter: None, - group_bys: Vec::new(), - having: None, - order_bys: Vec::new(), - limit: None, - offset: None, - _marker: PhantomData, - } - } - - fn push_filter(mut self, expr: QueryExpr, mode: FilterMode) -> Self { - self.filter = Some(match (mode, self.filter.take()) { - (FilterMode::Replace, _) => expr, - (FilterMode::And, Some(current)) => current.and(expr), - (FilterMode::Or, Some(current)) => current.or(expr), - (_, None) => expr, - }); - self - } - - fn push_order(mut self, order: SortExpr) -> Self { - self.order_bys.push(order); - self - } - - fn push_group_by(mut self, expr: QueryValue) -> Self { - self.group_bys.push(expr); - self - } - - fn with_distinct(mut self) -> Self { - self.distinct = true; - self - } - - fn push_join(mut self, join: JoinSpec) -> Self { - self.joins.push(join); - self - } -} - -impl MutationKind { - fn as_str(self) -> &'static str { - match self { - MutationKind::Update => "update", - MutationKind::Delete => "delete", - } - } -} - -impl UpdateAssignment { - fn new(column: &'static str, placeholder: &'static str, value: UpdateAssignmentValue) -> Self { - Self { - column, - placeholder, - value, - } - } - - fn into_parts(self) -> (Assignment, Option<(&'static str, DataValue)>) { - let placeholder = self.placeholder; - match self.value { - UpdateAssignmentValue::Param(value) => ( - Assignment { - target: AssignmentTarget::ColumnName(object_name(self.column)), - value: placeholder_expr(placeholder), - }, - Some((placeholder, value)), - ), - UpdateAssignmentValue::Expr(value) => ( - Assignment { - target: AssignmentTarget::ColumnName(object_name(self.column)), - value: value.into_expr(), - }, - None, - ), - } - } -} - -#[doc(hidden)] -pub trait ProjectionSpec { - fn into_select_items(self, relation: &str) -> Vec; -} - -/// Declares a struct-backed ORM projection used by [`FromBuilder::project`]. -/// -/// This trait is typically derived with `#[derive(Projection)]`. -/// -/// ```rust,ignore -/// #[derive(Default, kite_sql::Projection)] -/// struct UserSummary { -/// id: i32, -/// #[projection(rename = "user_name")] -/// display_name: String, -/// } -/// -/// let rows = database.from::().project::().fetch()?; -/// # Ok::<(), kite_sql::errors::DatabaseError>(()) -/// ``` -pub trait Projection: for<'a> From<(&'a SchemaRef, Tuple)> { - /// Returns the projected select-list items for model `M`. - fn projected_values(relation: &str) -> Vec; -} - -#[doc(hidden)] -pub trait IntoProjectedTuple { - fn into_projected_values(self) -> Vec; -} - -#[doc(hidden)] -pub trait IntoJoinColumns { - fn into_join_columns(self) -> Vec; -} - -#[doc(hidden)] -pub trait IntoInsertColumns { - fn into_insert_columns(self) -> Vec; -} - -#[doc(hidden)] -pub trait QueryOperand: private::Sealed + Sized { - type Source: StatementSource; - type Model: Model; - type Projection; - type Shape; - - fn into_query_parts(self) -> (Self::Source, Query); - - fn into_query(self) -> Query { - self.into_query_parts().1 - } - - fn union(self, rhs: R) -> SetQueryBuilder - where - R: QueryOperand, - { - let (source, left_query) = self.into_query_parts(); - SetQueryBuilder::new( - source, - set_operation_query( - left_query, - rhs.into_query(), - SetOperator::Union, - SetQuantifier::Distinct, - ), - ) - } - - fn except(self, rhs: R) -> SetQueryBuilder - where - R: QueryOperand, - { - let (source, left_query) = self.into_query_parts(); - SetQueryBuilder::new( - source, - set_operation_query( - left_query, - rhs.into_query(), - SetOperator::Except, - SetQuantifier::Distinct, - ), - ) - } - - fn intersect(self, rhs: R) -> SetQueryBuilder - where - R: QueryOperand, - { - let (source, left_query) = self.into_query_parts(); - SetQueryBuilder::new( - source, - set_operation_query( - left_query, - rhs.into_query(), - SetOperator::Intersect, - SetQuantifier::Distinct, - ), - ) - } -} - -impl IntoJoinColumns for Field { - fn into_join_columns(self) -> Vec { - vec![object_name(self.column)] - } -} - -impl IntoInsertColumns for Field { - fn into_insert_columns(self) -> Vec { - vec![ident(self.column)] - } -} - -macro_rules! impl_into_join_columns { - ($(($($name:ident),+)),+ $(,)?) => { - $( - impl<$($name),+> IntoJoinColumns for ($($name,)+) - where - $($name: IntoJoinColumns,)+ - { - #[allow(non_snake_case)] - fn into_join_columns(self) -> Vec { - let ($($name,)+) = self; - let mut columns = Vec::new(); - $(columns.extend($name.into_join_columns());)+ - columns - } - } - )+ - }; -} - -impl_into_join_columns!( - (A, B), - (A, B, C), - (A, B, C, D), - (A, B, C, D, E), - (A, B, C, D, E, F), - (A, B, C, D, E, F, G), - (A, B, C, D, E, F, G, H), -); - -macro_rules! impl_into_insert_columns { - ($(($($name:ident),+)),+ $(,)?) => { - $( - impl IntoInsertColumns for ($($name,)+) - where - $($name: IntoInsertColumns,)+ - { - #[allow(non_snake_case)] - fn into_insert_columns(self) -> Vec { - let ($($name,)+) = self; - let mut columns = Vec::new(); - $(columns.extend($name.into_insert_columns());)+ - columns - } - } - )+ - }; -} - -impl_into_insert_columns!( - (A, B), - (A, B, C), - (A, B, C, D), - (A, B, C, D, E), - (A, B, C, D, E, F), - (A, B, C, D, E, F, G), - (A, B, C, D, E, F, G, H), -); - -impl ProjectionSpec for ModelProjection { - fn into_select_items(self, relation: &str) -> Vec { - select_projection(M::fields(), relation) - } -} - -impl ProjectionSpec for ValueProjection { - fn into_select_items(self, _relation: &str) -> Vec { - vec![self.value.into_select_item()] - } -} - -impl ProjectionSpec for TupleProjection { - fn into_select_items(self, _relation: &str) -> Vec { - self.values - .into_iter() - .map(ProjectedValue::into_select_item) - .collect() - } -} - -impl ProjectionSpec for StructProjection { - fn into_select_items(self, relation: &str) -> Vec { - T::projected_values::(relation) - .into_iter() - .map(ProjectedValue::into_select_item) - .collect() - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum FilterMode { - Replace, - And, - Or, -} - -impl QueryBuilder { - fn new(source: Q) -> Self { - Self { - state: BuilderState::new(source, QuerySource::model::()), - projection: ModelProjection, - } - } -} - -impl QueryBuilder { - fn with_projection(self, projection: P2) -> QueryBuilder { - QueryBuilder { - state: self.state, - projection, - } - } - - fn with_alias(mut self, alias: impl Into) -> Self { - self.state.query_source = self.state.query_source.with_alias(alias); - self - } - - fn into_query_parts(self) -> (Q, Query) - where - P: ProjectionSpec, - { - let QueryBuilder { - state: - BuilderState { - source, - query_source, - joins, - distinct, - filter, - group_bys, - having, - order_bys, - limit, - offset, - .. - }, - projection, - } = self; - - ( - source, - select_query( - &query_source, - joins, - projection.into_select_items(query_source.relation_name()), - distinct, - filter, - group_bys, - having, - order_bys, - limit, - offset, - ), - ) - } -} - -impl FromBuilder { - fn from_inner(inner: QueryBuilder) -> Self { - Self { inner } - } - - fn with_projection(self, projection: P2) -> FromBuilder { - FromBuilder::from_inner(self.inner.with_projection(projection)) - } - - /// Applies a relation alias to the current source. - /// - /// ```rust,ignore - /// let user = database - /// .from::() - /// .alias("u") - /// .eq(User::id().qualify("u"), 1) - /// .get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn alias(self, alias: impl Into) -> Self { - FromBuilder::from_inner(self.inner.with_alias(alias)) - } - - /// Builds a `UNION` set query with another query of the same shape. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn union(self, rhs: R) -> SetQueryBuilder - where - Self: QueryOperand, - R: QueryOperand::Shape>, - { - QueryOperand::union(self, rhs) - } - - /// Builds an `EXCEPT` set query with another query of the same shape. - /// - /// ```rust,ignore - /// let users_without_orders = database - /// .from::() - /// .project_value(User::id()) - /// .except(database.from::().project_value(Order::user_id())) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn except(self, rhs: R) -> SetQueryBuilder - where - Self: QueryOperand, - R: QueryOperand::Shape>, - { - QueryOperand::except(self, rhs) - } - - /// Builds an `INTERSECT` set query with another query of the same shape. - /// - /// ```rust,ignore - /// let ordered_user_ids = database - /// .from::() - /// .project_value(User::id()) - /// .intersect(database.from::().project_value(Order::user_id())) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn intersect(self, rhs: R) -> SetQueryBuilder - where - Self: QueryOperand, - R: QueryOperand::Shape>, - { - QueryOperand::intersect(self, rhs) - } - - /// Inserts the current query result into a target model table. - /// - /// Use this when the query output is a partial projection and you want to - /// choose the destination columns explicitly. - /// - /// ```rust,ignore - /// database - /// .from::() - /// .project_tuple((ArchivedUser::id(), ArchivedUser::name())) - /// .insert_into::((User::id(), User::name()))?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn insert_into>( - self, - columns: C, - ) -> Result<(), DatabaseError> - where - Self: QueryOperand, - { - let (source, query) = QueryOperand::into_query_parts(self); - execute_insert_query( - source, - orm_insert_query_statement( - Target::table_name(), - columns.into_insert_columns(), - query, - false, - ), - ) - } - - /// Inserts the current query result with `INSERT OVERWRITE` semantics. - /// - /// Use this when the query output is a partial projection and you want to - /// choose the destination columns explicitly. - pub fn overwrite_into>( - self, - columns: C, - ) -> Result<(), DatabaseError> - where - Self: QueryOperand, - { - let (source, query) = QueryOperand::into_query_parts(self); - execute_insert_query( - source, - orm_insert_query_statement( - Target::table_name(), - columns.into_insert_columns(), - query, - true, - ), - ) - } -} - -impl SetQueryBuilder { - fn new(source: Q, query: Query) -> Self { - Self { - source, - query, - _marker: PhantomData, - } - } - - /// Marks the preceding set operation as `ALL`. - /// - /// ```rust,ignore - /// let total = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .all() - /// .count()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn all(mut self) -> Self { - set_query_quantifier(&mut self.query, SetQuantifier::All); - self - } - - /// Appends an ascending sort key to the set query result. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .asc(User::id()) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn asc>(mut self, value: V) -> Self { - query_push_order(&mut self.query, set_query_order_value(value.into()).asc()); - self - } - - /// Applies `NULLS FIRST` to the most recently added sort key. - /// - /// Tips: call this immediately after `asc(...)` or `desc(...)`. - /// Without a preceding sort key, this method is a no-op. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::age()) - /// .asc(User::age()) - /// .nulls_first() - /// .fetch::>()?; - /// # let _ = ids; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn nulls_first(mut self) -> Self { - query_set_last_order_nulls(&mut self.query, true); - self - } - - /// Appends a descending sort key to the set query result. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .desc(User::id()) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn desc>(mut self, value: V) -> Self { - query_push_order(&mut self.query, set_query_order_value(value.into()).desc()); - self - } - - /// Applies `NULLS LAST` to the most recently added sort key. - /// - /// Tips: call this immediately after `asc(...)` or `desc(...)`. - /// Without a preceding sort key, this method is a no-op. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::age()) - /// .desc(User::age()) - /// .nulls_last() - /// .fetch::>()?; - /// # let _ = ids; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn nulls_last(mut self) -> Self { - query_set_last_order_nulls(&mut self.query, false); - self - } - - /// Sets the set query `LIMIT`. - /// - /// ```rust,ignore - /// let top_two = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .asc(User::id()) - /// .limit(2) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn limit(mut self, limit: usize) -> Self { - query_set_limit(&mut self.query, limit); - self - } - - /// Sets the set query `OFFSET`. - /// - /// ```rust,ignore - /// let skipped = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .asc(User::id()) - /// .offset(1) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn offset(mut self, offset: usize) -> Self { - query_set_offset(&mut self.query, offset); - self - } - - /// Executes the set query and returns the raw result iterator. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .raw()?; - /// # rows.done()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn raw(self) -> Result { - execute_query(self.source, self.query) - } - - /// Returns the logical plan text for the current set query. - /// - /// ```rust,ignore - /// let plan = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .explain()?; - /// assert!(plan.contains("Union")); - /// # let _ = plan; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn explain(self) -> Result { - query_explain(self.source, self.query) - } - - /// Returns whether the set query produces at least one row. - /// - /// ```rust,ignore - /// let has_ids = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .exists()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn exists(self) -> Result { - query_exists(self.source, self.query) - } - - /// Returns the row count of the set query result. - /// - /// ```rust,ignore - /// let total = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .count()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn count(self) -> Result { - query_count(self.source, self.query) - } - - /// Appends `UNION` to the current set query. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .union(database.from::().project_value(Wallet::id())) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn union(self, rhs: R) -> Self - where - Self: QueryOperand, - R: QueryOperand::Shape>, - { - QueryOperand::union(self, rhs) - } - - /// Appends `EXCEPT` to the current set query. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .except(database.from::().project_value(Wallet::id())) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn except(self, rhs: R) -> Self - where - Self: QueryOperand, - R: QueryOperand::Shape>, - { - QueryOperand::except(self, rhs) - } - - /// Appends `INTERSECT` to the current set query. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .intersect(database.from::().project_value(Wallet::id())) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn intersect(self, rhs: R) -> Self - where - Self: QueryOperand, - R: QueryOperand::Shape>, - { - QueryOperand::intersect(self, rhs) - } - - /// Inserts the current query result into a target model table. - /// - /// Use this when the query output is a partial projection and you want to - /// choose the destination columns explicitly. - /// - /// ```rust,ignore - /// database - /// .from::() - /// .project_tuple((ArchivedUser::id(), ArchivedUser::name())) - /// .insert_into::((User::id(), User::name()))?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn insert_into>( - self, - columns: C, - ) -> Result<(), DatabaseError> - where - Self: QueryOperand, - { - let (source, query) = QueryOperand::into_query_parts(self); - execute_insert_query( - source, - orm_insert_query_statement( - Target::table_name(), - columns.into_insert_columns(), - query, - false, - ), - ) - } - - /// Inserts the current set-query result with `INSERT OVERWRITE` semantics. - /// - /// Use this when the query output is a partial projection and you want to - /// choose the destination columns explicitly. - pub fn overwrite_into>( - self, - columns: C, - ) -> Result<(), DatabaseError> - where - Self: QueryOperand, - { - let (source, query) = QueryOperand::into_query_parts(self); - execute_insert_query( - source, - orm_insert_query_statement( - Target::table_name(), - columns.into_insert_columns(), - query, - true, - ), - ) - } -} - -impl UpdateBuilder { - fn new(state: BuilderState) -> Self { - Self { - state, - assignments: Vec::new(), - } - } - - fn with_assignment( - mut self, - column: &'static str, - placeholder: &'static str, - value: UpdateAssignmentValue, - ) -> Self { - let assignment = UpdateAssignment::new(column, placeholder, value); - if let Some(existing) = self - .assignments - .iter_mut() - .find(|existing| existing.column == column) - { - *existing = assignment; - } else { - self.assignments.push(assignment); - } - self - } - - /// Assigns a constant value to a model field. - /// - /// ```rust,ignore - /// database - /// .from::() - /// .eq(User::id(), 1) - /// .update() - /// .set(User::name(), "Bob") - /// .set(User::age(), Some(20)) - /// .execute()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn set(self, field: Field, value: V) -> Self { - let orm_field = field.orm_field(); - self.with_assignment( - orm_field.column, - orm_field.placeholder, - UpdateAssignmentValue::Param(value.to_data_value()), - ) - } - - /// Assigns a computed SQL expression to a model field. - /// - /// ```rust,ignore - /// database - /// .from::() - /// .eq(User::id(), 1) - /// .update() - /// .set_expr(User::age(), User::age().add(1)) - /// .execute()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn set_expr>(self, field: Field, value: V) -> Self { - let orm_field = field.orm_field(); - self.with_assignment( - orm_field.column, - orm_field.placeholder, - UpdateAssignmentValue::Expr(value.into()), - ) - } - - /// Executes the update statement. - pub fn execute(self) -> Result<(), DatabaseError> { - if self.assignments.is_empty() { - return Err(DatabaseError::ColumnsEmpty); - } - - validate_mutation_state(&self.state, MutationKind::Update)?; - - let UpdateBuilder { state, assignments } = self; - let BuilderState { - source, - query_source, - filter, - .. - } = state; - - let mut statement_assignments = Vec::with_capacity(assignments.len()); - let mut params = Vec::new(); - for assignment in assignments { - let (assignment, param) = assignment.into_parts(); - statement_assignments.push(assignment); - if let Some(param) = param { - params.push(param); - } - } - - let statement = orm_update_builder_statement(&query_source, filter, statement_assignments); - source.execute_statement(&statement, params)?.done() - } -} - -impl> JoinOnBuilder { - /// Applies a relation alias to the pending join source. - /// - /// This is mainly useful for self-joins or when you want explicit source - /// names in projected columns. - pub fn alias(mut self, alias: impl Into) -> Self { - self.join_source = self.join_source.with_alias(alias); - self - } - - /// Completes a pending join with an `ON` condition. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .inner_join::() - /// .on(User::id().eq(Order::user_id())) - /// .project_tuple((User::name(), Order::amount())) - /// .fetch::<(String, i32)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn on(self, expr: QueryExpr) -> FromBuilder { - FromBuilder::from_inner(self.inner.join( - self.join_source, - self.join_kind, - JoinConstraint::On(expr.into_expr()), - )) - } - - /// Completes a pending join with a `USING (...)` column list. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .inner_join::() - /// .using(User::id()) - /// .project_tuple((User::name(), Wallet::balance())) - /// .fetch::<(String, rust_decimal::Decimal)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn using(self, columns: C) -> FromBuilder { - FromBuilder::from_inner(self.inner.join( - self.join_source, - self.join_kind, - JoinConstraint::Using(columns.into_join_columns()), - )) - } -} - -impl FromBuilder { - /// Inserts full-row query results into another model table. - /// - /// This is the query-builder form of `INSERT INTO ... SELECT ...` when the - /// source query already yields all destination columns in order. - /// - /// ```rust,ignore - /// database.from::().insert::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn insert(self) -> Result<(), DatabaseError> { - let (source, query) = QueryOperand::into_query_parts(self); - execute_insert_query( - source, - orm_insert_query_statement( - Target::table_name(), - model_insert_columns::(), - query, - false, - ), - ) - } - - /// Inserts the current full-row query result with `INSERT OVERWRITE` semantics. - pub fn overwrite(self) -> Result<(), DatabaseError> { - let (source, query) = QueryOperand::into_query_parts(self); - execute_insert_query( - source, - orm_insert_query_statement( - Target::table_name(), - model_insert_columns::(), - query, - true, - ), - ) - } - - /// Starts a single-table `UPDATE` builder for the current model source. - /// - /// Chain one or more `.set(...)` or `.set_expr(...)` calls, then finish - /// with `.execute()`. - /// - /// ```rust,ignore - /// database - /// .from::() - /// .eq(User::id(), 1) - /// .update() - /// .set(User::name(), "Bob") - /// .execute()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn update(self) -> UpdateBuilder { - UpdateBuilder::new(self.inner.state) - } - - /// Executes a single-table `DELETE` for the current filtered source. - /// - /// ```rust,ignore - /// database - /// .from::() - /// .eq(User::id(), 1) - /// .delete()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn delete(self) -> Result<(), DatabaseError> { - self.inner.delete() - } - - /// Switches the query into a struct projection. - /// - /// ```rust,ignore - /// #[derive(Default, kite_sql::Projection)] - /// struct UserSummary { - /// id: i32, - /// #[projection(rename = "user_name")] - /// display_name: String, - /// } - /// - /// let users = database.from::().project::().fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn project(self) -> FromBuilder> { - self.with_projection(StructProjection { - _marker: PhantomData, - }) - } - - /// Switches the query into a single-value projection. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::id()) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn project_value>( - self, - value: V, - ) -> FromBuilder { - self.with_projection(ValueProjection { - value: value.into(), - }) - } - - /// Switches the query into a tuple projection. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .project_tuple((User::id(), User::name())) - /// .fetch::<(i32, String)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn project_tuple( - self, - values: V, - ) -> FromBuilder { - self.with_projection(TupleProjection { - values: values.into_projected_values(), - }) - } - - /// Executes the query and decodes rows into the model type. - pub fn fetch(self) -> Result, DatabaseError> { - Ok(self.raw()?.orm::()) - } - - /// Executes the query with `LIMIT 1` semantics and decodes one model row. - pub fn get(self) -> Result, DatabaseError> { - extract_optional_model(self.limit(1).raw()?) - } -} - -impl> FromBuilder { - /// Starts an `INNER JOIN` against another model source. - /// - /// Call `.on(...)` or `.using(...)` to supply the join condition. Use `.alias(...)` only - /// when explicit qualification is needed, such as self-joins. - pub fn inner_join(self) -> JoinOnBuilder { - JoinOnBuilder { - inner: self.inner, - join_source: QuerySource::model::(), - join_kind: JoinKind::Inner, - } - } - - /// Starts a `LEFT JOIN` against another model source. - /// - /// Call `.on(...)` or `.using(...)` to supply the join condition. Use `.alias(...)` only - /// when explicit qualification is needed, such as self-joins. - pub fn left_join(self) -> JoinOnBuilder { - JoinOnBuilder { - inner: self.inner, - join_source: QuerySource::model::(), - join_kind: JoinKind::Left, - } - } - - /// Starts a `RIGHT JOIN` against another model source. - /// - /// Call `.on(...)` or `.using(...)` to supply the join condition. Use `.alias(...)` only - /// when explicit qualification is needed, such as self-joins. - pub fn right_join(self) -> JoinOnBuilder { - JoinOnBuilder { - inner: self.inner, - join_source: QuerySource::model::(), - join_kind: JoinKind::Right, - } - } - - /// Starts a `FULL OUTER JOIN` against another model source. - /// - /// Call `.on(...)` or `.using(...)` to supply the join condition. Use `.alias(...)` only - /// when explicit qualification is needed, such as self-joins. - pub fn full_join(self) -> JoinOnBuilder { - JoinOnBuilder { - inner: self.inner, - join_source: QuerySource::model::(), - join_kind: JoinKind::Full, - } - } - - /// Starts a `CROSS JOIN` against another model source. - /// - /// `CROSS JOIN` does not take an `ON` or `USING` clause. - pub fn cross_join(self) -> Self { - Self::from_inner(self.inner.join( - QuerySource::model::(), - JoinKind::Cross, - JoinConstraint::None, - )) - } - - /// Replaces the current `WHERE` predicate. - /// - /// ```rust,ignore - /// let adults = database - /// .from::() - /// .filter(User::age().gte(18)) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn filter(self, expr: QueryExpr) -> Self { - Self::from_inner(self.inner.filter(expr)) - } - - /// Applies `SELECT DISTINCT` to the current query. - /// - /// ```rust,ignore - /// let categories = database - /// .from::() - /// .distinct() - /// .project_value(EventLog::category()) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn distinct(self) -> Self { - Self::from_inner(self.inner.distinct()) - } - - /// Appends `left AND right` to the current filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .and(User::age().gte(18), User::name().like("A%")) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn and(self, left: QueryExpr, right: QueryExpr) -> Self { - Self::from_inner(self.inner.and(left, right)) - } - - /// Appends `left OR right` to the current filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .or(User::name().like("A%"), User::name().like("B%")) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn or(self, left: QueryExpr, right: QueryExpr) -> Self { - Self::from_inner(self.inner.or(left, right)) - } - - /// Replaces the current filter with `NOT expr`. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .not(User::name().like("A%")) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn not(self, expr: QueryExpr) -> Self { - Self::from_inner(self.inner.not(expr)) - } - - /// Replaces the current filter with `EXISTS (subquery)`. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .where_exists( - /// database - /// .from::() - /// .project_value(Order::id()) - /// .eq(Order::user_id(), User::id()), - /// ) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn where_exists(self, subquery: S) -> Self { - Self::from_inner(self.inner.where_exists(subquery)) - } - - /// Replaces the current filter with `NOT EXISTS (subquery)`. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .where_not_exists( - /// database - /// .from::() - /// .project_value(Order::id()) - /// .eq(Order::user_id(), User::id()), - /// ) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn where_not_exists(self, subquery: S) -> Self { - Self::from_inner(self.inner.where_not_exists(subquery)) - } - - /// Appends a `GROUP BY` expression. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .project_tuple((EventLog::category(), kite_sql::orm::count(EventLog::id()))) - /// .group_by(EventLog::category()) - /// .fetch::<(String, i32)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn group_by>(self, value: V) -> Self { - Self::from_inner(self.inner.group_by(value)) - } - - /// Sets the `HAVING` predicate. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .project_tuple((EventLog::category(), kite_sql::orm::count(EventLog::id()))) - /// .group_by(EventLog::category()) - /// .having(kite_sql::orm::count(EventLog::id()).gt(10)) - /// .fetch::<(String, i32)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn having(self, expr: QueryExpr) -> Self { - Self::from_inner(self.inner.having(expr)) - } - - /// Appends an ascending sort key. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .asc(User::name()) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn asc>(self, value: V) -> Self { - Self::from_inner(self.inner.asc(value)) - } - - /// Applies `NULLS FIRST` to the most recently added sort key. - /// - /// Tips: call this immediately after `asc(...)` or `desc(...)`. - /// Without a preceding sort key, this method is a no-op. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .asc(User::age()) - /// .nulls_first() - /// .fetch()?; - /// # let _ = users; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn nulls_first(self) -> Self { - Self::from_inner(self.inner.nulls_first()) - } - - /// Applies `NULLS LAST` to the most recently added sort key. - /// - /// Tips: call this immediately after `asc(...)` or `desc(...)`. - /// Without a preceding sort key, this method is a no-op. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .desc(User::age()) - /// .nulls_last() - /// .fetch()?; - /// # let _ = users; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn nulls_last(self) -> Self { - Self::from_inner(self.inner.nulls_last()) - } - - /// Appends a descending sort key. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .desc(User::id()) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn desc>(self, value: V) -> Self { - Self::from_inner(self.inner.desc(value)) - } - - /// Sets the query `LIMIT`. - /// - /// ```rust,ignore - /// let first_two = database.from::().asc(User::id()).limit(2).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn limit(self, limit: usize) -> Self { - Self::from_inner(self.inner.limit(limit)) - } - - /// Sets the query `OFFSET`. - /// - /// ```rust,ignore - /// let later_users = database.from::().asc(User::id()).offset(1).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn offset(self, offset: usize) -> Self { - Self::from_inner(self.inner.offset(offset)) - } - - /// Appends `left = right` to the filter state. - /// - /// ```rust,ignore - /// let user = database.from::().eq(User::id(), 1).get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn eq, R: Into>(self, left: L, right: R) -> Self { - Self::from_inner(self.inner.eq(left, right)) - } - - /// Appends `left <> right` to the filter state. - /// - /// ```rust,ignore - /// let users = database.from::().ne(User::id(), 1).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn ne, R: Into>(self, left: L, right: R) -> Self { - Self::from_inner(self.inner.ne(left, right)) - } - - /// Appends `left > right` to the filter state. - /// - /// ```rust,ignore - /// let adults = database.from::().gt(User::age(), 18).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn gt, R: Into>(self, left: L, right: R) -> Self { - Self::from_inner(self.inner.gt(left, right)) - } - - /// Appends `left >= right` to the filter state. - /// - /// ```rust,ignore - /// let adults = database.from::().gte(User::age(), 18).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn gte, R: Into>(self, left: L, right: R) -> Self { - Self::from_inner(self.inner.gte(left, right)) - } - - /// Appends `left < right` to the filter state. - /// - /// ```rust,ignore - /// let younger = database.from::().lt(User::age(), 18).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn lt, R: Into>(self, left: L, right: R) -> Self { - Self::from_inner(self.inner.lt(left, right)) - } - - /// Appends `left <= right` to the filter state. - /// - /// ```rust,ignore - /// let users = database.from::().lte(User::id(), 10).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn lte, R: Into>(self, left: L, right: R) -> Self { - Self::from_inner(self.inner.lte(left, right)) - } - - /// Appends `value IS NULL` to the filter state. - /// - /// ```rust,ignore - /// let users = database.from::().is_null(User::age()).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn is_null>(self, value: V) -> Self { - Self::from_inner(self.inner.is_null(value)) - } - - /// Appends `value IS NOT NULL` to the filter state. - /// - /// ```rust,ignore - /// let users = database.from::().is_not_null(User::age()).fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn is_not_null>(self, value: V) -> Self { - Self::from_inner(self.inner.is_not_null(value)) - } - - /// Appends `value LIKE pattern` to the filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .like(User::name(), "A%") - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn like, R: Into>(self, value: L, pattern: R) -> Self { - Self::from_inner(self.inner.like(value, pattern)) - } - - /// Appends `value NOT LIKE pattern` to the filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .not_like(User::name(), "A%") - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn not_like, R: Into>(self, value: L, pattern: R) -> Self { - Self::from_inner(self.inner.not_like(value, pattern)) - } - - /// Appends `left IN (...)` to the filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .in_list(User::id(), [1, 2, 3]) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn in_list(self, left: L, values: I) -> Self - where - L: Into, - I: IntoIterator, - V: Into, - { - Self::from_inner(self.inner.in_list(left, values)) - } - - /// Appends `left NOT IN (...)` to the filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .not_in_list(User::id(), [1, 2, 3]) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn not_in_list(self, left: L, values: I) -> Self - where - L: Into, - I: IntoIterator, - V: Into, - { - Self::from_inner(self.inner.not_in_list(left, values)) - } - - /// Appends `expr BETWEEN low AND high` to the filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .between(User::age(), 18, 30) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn between(self, expr: L, low: Low, high: High) -> Self - where - L: Into, - Low: Into, - High: Into, - { - Self::from_inner(self.inner.between(expr, low, high)) - } - - /// Appends `expr NOT BETWEEN low AND high` to the filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .not_between(User::age(), 18, 30) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn not_between(self, expr: L, low: Low, high: High) -> Self - where - L: Into, - Low: Into, - High: Into, - { - Self::from_inner(self.inner.not_between(expr, low, high)) - } - - /// Appends `left IN (subquery)` to the filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .in_subquery( - /// User::id(), - /// database.from::().project_value(Order::user_id()), - /// ) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn in_subquery, S: SubquerySource>(self, left: L, subquery: S) -> Self { - Self::from_inner(self.inner.in_subquery(left, subquery)) - } - - /// Appends `left NOT IN (subquery)` to the filter state. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .not_in_subquery( - /// User::id(), - /// database.from::().project_value(Order::user_id()), - /// ) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn not_in_subquery, S: SubquerySource>( - self, - left: L, - subquery: S, - ) -> Self { - Self::from_inner(self.inner.not_in_subquery(left, subquery)) - } - - /// Executes the query and returns the raw result iterator. - /// - /// This is mainly useful when you want access to the raw tuple/schema pair - /// instead of ORM decoding. - /// - /// ```rust,ignore - /// let rows = database.from::().eq(User::id(), 1).raw()?; - /// # rows.done()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn raw(self) -> Result { - self.inner.raw() - } - - /// Returns the logical plan text for the current query. - /// - /// ```rust,ignore - /// let plan = database - /// .from::() - /// .eq(User::id(), 1) - /// .project_value(User::name()) - /// .explain()?; - /// assert!(plan.contains("TableScan")); - /// # let _ = plan; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn explain(self) -> Result { - self.inner.explain() - } - - /// Returns whether the query produces at least one row. - /// - /// ```rust,ignore - /// let found = database.from::().eq(User::id(), 1).exists()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn exists(self) -> Result { - self.inner.exists() - } - - /// Returns the row count for the current query shape. - /// - /// ```rust,ignore - /// let adult_count = database.from::().gte(User::age(), 18).count()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn count(self) -> Result { - self.inner.count() - } -} - -impl FromBuilder { - /// Executes a single-value projection and decodes each row into `T`. - /// - /// ```rust,ignore - /// let ids = database.from::().project_value(User::id()).fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn fetch(self) -> Result, DatabaseError> { - Ok(ProjectValueIter::new(self.raw()?)) - } - - /// Executes a single-value projection and decodes one value. - /// - /// ```rust,ignore - /// let first_id = database - /// .from::() - /// .project_value(User::id()) - /// .asc(User::id()) - /// .get::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn get(self) -> Result, DatabaseError> { - extract_optional_value(self.limit(1).raw()?) - } -} - -impl FromBuilder { - /// Executes a tuple projection and decodes each row into `T`. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .project_tuple((User::id(), User::name())) - /// .fetch::<(i32, String)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn fetch(self) -> Result, DatabaseError> { - Ok(ProjectTupleIter::new(self.raw()?)) - } - - /// Executes a tuple projection and decodes one row into `T`. - /// - /// ```rust,ignore - /// let row = database - /// .from::() - /// .project_tuple((User::id(), User::name())) - /// .asc(User::id()) - /// .get::<(i32, String)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn get(self) -> Result, DatabaseError> { - extract_optional_tuple(self.limit(1).raw()?) - } -} - -impl FromBuilder> { - /// Executes a struct projection and decodes each row into `T`. - /// - /// ```rust,ignore - /// #[derive(Default, kite_sql::Projection)] - /// struct UserSummary { - /// id: i32, - /// #[projection(rename = "user_name")] - /// display_name: String, - /// } - /// - /// let rows = database.from::().project::().fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn fetch(self) -> Result, DatabaseError> { - Ok(self.raw()?.orm::()) - } - - /// Executes a struct projection and decodes one row into `T`. - /// - /// ```rust,ignore - /// #[derive(Default, kite_sql::Projection)] - /// struct UserSummary { - /// id: i32, - /// #[projection(rename = "user_name")] - /// display_name: String, - /// } - /// - /// let row = database.from::().project::().get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn get(self) -> Result, DatabaseError> { - extract_optional_row(self.limit(1).raw()?) - } -} - -impl SetQueryBuilder { - /// Inserts full-row set-query results into another model table. - /// - /// ```rust,ignore - /// database - /// .from::() - /// .union(database.from::()) - /// .insert::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn insert(self) -> Result<(), DatabaseError> { - let (source, query) = QueryOperand::into_query_parts(self); - execute_insert_query( - source, - orm_insert_query_statement( - Target::table_name(), - model_insert_columns::(), - query, - false, - ), - ) - } - - /// Inserts the current full-row set-query result with `INSERT OVERWRITE` semantics. - pub fn overwrite(self) -> Result<(), DatabaseError> { - let (source, query) = QueryOperand::into_query_parts(self); - execute_insert_query( - source, - orm_insert_query_statement( - Target::table_name(), - model_insert_columns::(), - query, - true, - ), - ) - } - - /// Executes the set query and decodes rows into the model type. - /// - /// ```rust,ignore - /// let users = database - /// .from::() - /// .union(database.from::()) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn fetch(self) -> Result, DatabaseError> { - Ok(self.raw()?.orm::()) - } - - /// Executes the set query with `LIMIT 1` semantics and decodes one model row. - /// - /// ```rust,ignore - /// let user = database - /// .from::() - /// .union(database.from::()) - /// .get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn get(self) -> Result, DatabaseError> { - extract_optional_model(self.limit(1).raw()?) - } -} - -impl SetQueryBuilder { - /// Executes the set query and decodes each row into `T`. - /// - /// ```rust,ignore - /// let ids = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .fetch::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn fetch(self) -> Result, DatabaseError> { - Ok(ProjectValueIter::new(self.raw()?)) - } - - /// Executes the set query with `LIMIT 1` semantics and decodes one value. - /// - /// ```rust,ignore - /// let id = database - /// .from::() - /// .project_value(User::id()) - /// .union(database.from::().project_value(Order::user_id())) - /// .asc(User::id()) - /// .get::()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn get(self) -> Result, DatabaseError> { - extract_optional_value(self.limit(1).raw()?) - } -} - -impl SetQueryBuilder { - /// Executes the set query and decodes each row into `T`. - /// - /// ```rust,ignore - /// let rows = database - /// .from::() - /// .project_tuple((User::id(), User::name())) - /// .union(database.from::().project_tuple((ArchivedUser::id(), ArchivedUser::name()))) - /// .fetch::<(i32, String)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn fetch(self) -> Result, DatabaseError> { - Ok(ProjectTupleIter::new(self.raw()?)) - } - - /// Executes the set query with `LIMIT 1` semantics and decodes one tuple row. - /// - /// ```rust,ignore - /// let row = database - /// .from::() - /// .project_tuple((User::id(), User::name())) - /// .union(database.from::().project_tuple((ArchivedUser::id(), ArchivedUser::name()))) - /// .get::<(i32, String)>()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn get(self) -> Result, DatabaseError> { - extract_optional_tuple(self.limit(1).raw()?) - } -} - -impl SetQueryBuilder> { - /// Executes the set query and decodes rows into the struct projection type. - /// - /// ```rust,ignore - /// #[derive(Default, kite_sql::Projection)] - /// struct UserSummary { - /// id: i32, - /// #[projection(rename = "user_name")] - /// display_name: String, - /// } - /// - /// let rows = database - /// .from::() - /// .project::() - /// .union(database.from::().project::()) - /// .fetch()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn fetch(self) -> Result, DatabaseError> { - Ok(self.raw()?.orm::()) - } - - /// Executes the set query with `LIMIT 1` semantics and decodes one projected row. - /// - /// ```rust,ignore - /// #[derive(Default, kite_sql::Projection)] - /// struct UserSummary { - /// id: i32, - /// #[projection(rename = "user_name")] - /// display_name: String, - /// } - /// - /// let row = database - /// .from::() - /// .project::() - /// .union(database.from::().project::()) - /// .get()?; - /// # Ok::<(), kite_sql::errors::DatabaseError>(()) - /// ``` - pub fn get(self) -> Result, DatabaseError> { - extract_optional_row(self.limit(1).raw()?) - } -} - -impl> QueryBuilder { - fn push_filter(self, expr: QueryExpr, mode: FilterMode) -> Self { - Self { - state: self.state.push_filter(expr, mode), - projection: self.projection, - } - } - - fn join( - self, - join_source: QuerySource, - join_kind: JoinKind, - constraint: JoinConstraint, - ) -> Self { - Self { - state: self.state.push_join(JoinSpec { - source: join_source, - kind: join_kind, - constraint, - }), - projection: self.projection, - } - } - - fn filter(mut self, expr: QueryExpr) -> Self { - self.state.filter = Some(expr); - self - } - - fn distinct(self) -> Self { - Self { - state: self.state.with_distinct(), - projection: self.projection, - } - } - - fn and(self, left: QueryExpr, right: QueryExpr) -> Self { - Self { - state: self.state.push_filter(left.and(right), FilterMode::And), - projection: self.projection, - } - } - - fn or(self, left: QueryExpr, right: QueryExpr) -> Self { - Self { - state: self.state.push_filter(left.or(right), FilterMode::Or), - projection: self.projection, - } - } - - fn not(self, expr: QueryExpr) -> Self { - Self { - state: self.state.push_filter(expr.not(), FilterMode::Replace), - projection: self.projection, - } - } - - fn where_exists(self, subquery: S) -> Self { - Self { - state: self - .state - .push_filter(QueryExpr::exists(subquery), FilterMode::Replace), - projection: self.projection, - } - } - - fn where_not_exists(self, subquery: S) -> Self { - Self { - state: self - .state - .push_filter(QueryExpr::not_exists(subquery), FilterMode::Replace), - projection: self.projection, - } - } - - fn group_by>(self, value: V) -> Self { - Self { - state: self.state.push_group_by(value.into()), - projection: self.projection, - } - } - - fn having(mut self, expr: QueryExpr) -> Self { - self.state.having = Some(expr); - self - } - - fn asc>(self, value: V) -> Self { - Self { - state: self.state.push_order(value.into().asc()), - projection: self.projection, - } - } - - fn nulls_first(mut self) -> Self { - if let Some(last) = self.state.order_bys.last_mut() { - *last = last.clone().with_nulls(true); - } - self - } - - fn nulls_last(mut self) -> Self { - if let Some(last) = self.state.order_bys.last_mut() { - *last = last.clone().with_nulls(false); - } - self - } - - fn desc>(self, value: V) -> Self { - Self { - state: self.state.push_order(value.into().desc()), - projection: self.projection, - } - } - - fn limit(mut self, limit: usize) -> Self { - self.state.limit = Some(limit); - self + pub fn case_when( + &self, + expr_pairs: impl IntoIterator, + else_expr: Option, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> + where + C: IntoOrmScalarExpression, + V: IntoOrmScalarExpression, + E: IntoOrmScalarExpression, + { + let expr_pairs = expr_pairs + .into_iter() + .map(|(condition, value)| (condition.into_orm_scalar(), value.into_orm_scalar())) + .collect::>(); + let else_expr = else_expr.map(IntoOrmScalarExpression::into_orm_scalar); + let ty = expr_pairs + .first() + .map(|(_, value)| value.return_type(self.arena).into_owned()) + .or_else(|| { + else_expr + .as_ref() + .map(|value| value.return_type(self.arena).into_owned()) + }) + .unwrap_or(LogicalType::SqlNull); + self.wrap(ScalarExpression::CaseWhen { + operand_expr: None, + expr_pairs, + else_expr: else_expr.map(Box::new), + ty, + }) } - fn offset(mut self, offset: usize) -> Self { - self.state.offset = Some(offset); - self + pub fn case_value( + &self, + operand_expr: impl IntoOrmScalarExpression, + expr_pairs: impl IntoIterator, + else_expr: Option, + ) -> CtxExpression<'bind, 'parent, 'arena, T, A> + where + K: IntoOrmScalarExpression, + V: IntoOrmScalarExpression, + E: IntoOrmScalarExpression, + { + let expr_pairs = expr_pairs + .into_iter() + .map(|(key, value)| (key.into_orm_scalar(), value.into_orm_scalar())) + .collect::>(); + let else_expr = else_expr.map(IntoOrmScalarExpression::into_orm_scalar); + let ty = expr_pairs + .first() + .map(|(_, value)| value.return_type(self.arena).into_owned()) + .or_else(|| { + else_expr + .as_ref() + .map(|value| value.return_type(self.arena).into_owned()) + }) + .unwrap_or(LogicalType::SqlNull); + self.wrap(ScalarExpression::CaseWhen { + operand_expr: Some(Box::new(operand_expr.into_orm_scalar())), + expr_pairs, + else_expr: else_expr.map(Box::new), + ty, + }) } - fn build_query(self) -> Query { - let QueryBuilder { - state: - BuilderState { - source: _, - query_source, - joins, - distinct, - filter, - group_bys, - having, - order_bys, - limit, - offset, - .. - }, - projection, - } = self; - - select_query( - &query_source, - joins, - projection.into_select_items(query_source.relation_name()), - distinct, - filter, - group_bys, - having, - order_bys, - limit, - offset, + pub fn scalar_subquery( + &self, + build: F, + ) -> Result, DatabaseError> + where + F: for<'scope, 'sub_bind, 'sub_parent> FnOnce( + &'scope mut OrmContext<'scope, 'sub_bind, 'sub_parent, 'arena, T, A>, ) + -> Result, + { + self.handle().scalar_subquery(build) } - fn into_statement(self) -> (Q, Statement) { - let QueryBuilder { - state: - BuilderState { - source, - query_source, - joins, - distinct, - filter, - group_bys, - having, - order_bys, - limit, - offset, - .. - }, - projection, - } = self; - - let statement = Statement::Query(Box::new(select_query( - &query_source, - joins, - projection.into_select_items(query_source.relation_name()), - distinct, - filter, - group_bys, - having, - order_bys, - limit, - offset, - ))); - - (source, statement) - } - - fn eq, R: Into>(self, left: L, right: R) -> Self { - self.push_filter(left.into().eq(right), FilterMode::Replace) - } - - fn ne, R: Into>(self, left: L, right: R) -> Self { - self.push_filter(left.into().ne(right), FilterMode::Replace) - } - - fn gt, R: Into>(self, left: L, right: R) -> Self { - self.push_filter(left.into().gt(right), FilterMode::Replace) - } - - fn gte, R: Into>(self, left: L, right: R) -> Self { - self.push_filter(left.into().gte(right), FilterMode::Replace) - } - - fn lt, R: Into>(self, left: L, right: R) -> Self { - self.push_filter(left.into().lt(right), FilterMode::Replace) - } - - fn lte, R: Into>(self, left: L, right: R) -> Self { - self.push_filter(left.into().lte(right), FilterMode::Replace) - } - - #[allow(clippy::wrong_self_convention)] - fn is_null>(self, value: V) -> Self { - self.push_filter(value.into().is_null(), FilterMode::Replace) - } - - #[allow(clippy::wrong_self_convention)] - fn is_not_null>(self, value: V) -> Self { - self.push_filter(value.into().is_not_null(), FilterMode::Replace) - } - - fn like, R: Into>(self, value: L, pattern: R) -> Self { - self.push_filter(value.into().like(pattern), FilterMode::Replace) - } - - fn not_like, R: Into>(self, value: L, pattern: R) -> Self { - self.push_filter(value.into().not_like(pattern), FilterMode::Replace) - } - - fn in_list(self, left: L, values: I) -> Self + pub fn exists_subquery( + &self, + negated: bool, + build: F, + ) -> Result, DatabaseError> where - L: Into, - I: IntoIterator, - V: Into, + F: for<'scope, 'sub_bind, 'sub_parent> FnOnce( + &'scope mut OrmContext<'scope, 'sub_bind, 'sub_parent, 'arena, T, A>, + ) + -> Result, { - Self { - state: self - .state - .push_filter(left.into().in_list(values), FilterMode::Replace), - projection: self.projection, - } + self.handle().exists_subquery(negated, build) } +} - fn not_in_list(self, left: L, values: I) -> Self +impl<'ctx, 'bind, 'parent, 'arena, T, A> UpdateBindScope<'ctx, 'bind, 'parent, 'arena, T, A> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + pub fn set_value(&mut self, field: Field, value: D) -> Result<(), DatabaseError> where - L: Into, - I: IntoIterator, - V: Into, + D: ToDataValue, { - Self { - state: self - .state - .push_filter(left.into().not_in_list(values), FilterMode::Replace), - projection: self.projection, - } + let expr = ScalarExpression::Constant(value.to_data_value()); + self.push_assignment(field.column, expr) } - fn between(self, expr: L, low: Low, high: High) -> Self + pub fn set(&mut self, field: Field, value: D) -> Result<(), DatabaseError> where - L: Into, - Low: Into, - High: Into, + D: ToDataValue, { - Self { - state: self - .state - .push_filter(expr.into().between(low, high), FilterMode::Replace), - projection: self.projection, - } + self.set_value(field, value) + } + + pub fn set_bound_expr( + &mut self, + field: Field, + build: impl BindOrmScalar<'bind, 'parent, 'arena, T, A>, + ) -> Result<(), DatabaseError> { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = ExprBindScope { + binder: self.binder, + arena: self.arena, + }; + build.bind_scalar(&mut scope)? + }); + self.push_assignment(field.column, expr?) } - fn not_between(self, expr: L, low: Low, high: High) -> Self + pub fn set_expr( + &mut self, + field: Field, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result<(), DatabaseError> where - L: Into, - Low: Into, - High: Into, + E: IntoOrmScalarExpression, { - Self { - state: self - .state - .push_filter(expr.into().not_between(low, high), FilterMode::Replace), - projection: self.projection, - } - } - - fn in_subquery, S: SubquerySource>(self, left: L, subquery: S) -> Self { - Self { - state: self - .state - .push_filter(left.into().in_subquery(subquery), FilterMode::Replace), - projection: self.projection, - } - } - - fn not_in_subquery, S: SubquerySource>(self, left: L, subquery: S) -> Self { - Self { - state: self - .state - .push_filter(left.into().not_in_subquery(subquery), FilterMode::Replace), - projection: self.projection, + let expr = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = ExprBindScope { + binder: self.binder, + arena: self.arena, + }; + build(&mut scope)?.into_orm_scalar() + }); + self.push_assignment(field.column, expr?) + } + + fn push_assignment( + &mut self, + column_name: &str, + mut expr: ScalarExpression, + ) -> Result<(), DatabaseError> { + let column = + bind_orm_target_column(self.binder, &self.source_name, column_name, self.arena)?; + if matches!(expr, ScalarExpression::Empty) { + let column_catalog = self.arena.column(column); + let default_value = column_catalog + .default_value()? + .ok_or(DatabaseError::DefaultNotExist)?; + expr = ScalarExpression::Constant(default_value); } + let column_catalog = self.arena.column(column); + expr = ScalarExpression::type_cast( + expr, + Cow::Borrowed(column_catalog.datatype()), + self.arena, + )?; + self.value_exprs.push((column, expr)); + Ok(()) } - fn raw(self) -> Result { - let (source, statement) = self.into_statement(); - match statement { - Statement::Query(query) => execute_query(source, *query), - _ => source.execute_statement(&statement, &[]), + fn finish( + self, + table_name: TableName, + plan: LogicalPlan, + ) -> Result { + self.binder.context.allow_default = false; + if self.value_exprs.is_empty() { + return Err(DatabaseError::ColumnsEmpty); } + self.binder.bind_update(table_name, self.value_exprs, plan) } +} - fn explain(self) -> Result { - let (source, query) = self.into_query_parts(); - query_explain(source, query) +impl<'scope_ctx, 'bind, 'parent, 'arena, T, A, M> + BindPlanFrom<'scope_ctx, 'bind, 'parent, 'arena, T, A, M> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, + M: Model, +{ + fn model_table_name(&self) -> Result { + Ok(M::table_name().into()) } - fn exists(self) -> Result { - let mut iter = self.limit(1).raw()?; - Ok(iter.next().transpose()?.is_some()) + fn model_relation_name(&self) -> Result { + let table_name = M::table_name(); + self.binder + .context + .bind_table + .iter() + .rev() + .find(|source| source.table_name.as_ref() == table_name) + .map(|source| source.visible_name().to_string()) + .ok_or_else(|| DatabaseError::invalid_table(table_name)) } - fn count(self) -> Result { - let is_shape_sensitive = self.state.distinct - || !self.state.group_bys.is_empty() - || self.state.having.is_some() - || self.state.limit.is_some() - || self.state.offset.is_some(); - if is_shape_sensitive { - let mut iter = self.raw()?; - let mut count = 0usize; - while iter.next().transpose()?.is_some() { - count += 1; - } - iter.done()?; - return Ok(count); + fn expr_scope<'scope>(&'scope mut self) -> ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A> { + ExprBindScope { + binder: self.binder, + arena: self.arena, } - - let BuilderState { - source, - query_source, - joins, - filter, - .. - } = self.state; - let statement = orm_count_statement(&query_source, joins, filter); - let mut iter = source.execute_statement(&statement, &[])?; - let count = match iter.next().transpose()? { - Some(tuple) => match tuple.values.first() { - Some(DataValue::Int32(value)) => *value as usize, - Some(DataValue::Int64(value)) => *value as usize, - Some(DataValue::UInt32(value)) => *value as usize, - Some(DataValue::UInt64(value)) => *value as usize, - other => { - return Err(DatabaseError::InvalidValue(format!( - "unexpected count result: {other:?}" - ))) - } - }, - None => 0, - }; - iter.done()?; - Ok(count) - } - - fn delete(self) -> Result<(), DatabaseError> { - validate_mutation_state(&self.state, MutationKind::Delete)?; - - let BuilderState { - source, - query_source, - filter, - .. - } = self.state; - - source - .execute_statement(&orm_delete_builder_statement(&query_source, filter), &[])? - .done() - } -} - -impl private::Sealed for FromBuilder {} - -impl private::Sealed for SetQueryBuilder {} - -impl> SubquerySource for FromBuilder { - fn into_subquery(self) -> Query { - self.inner.build_query() } -} - -impl SubquerySource for SetQueryBuilder { - fn into_subquery(self) -> Query { - self.query - } -} -impl QueryOperand for FromBuilder { - type Source = Q; - type Model = M; - type Projection = ModelProjection; - type Shape = M; - - fn into_query_parts(self) -> (Self::Source, Query) { - self.inner.into_query_parts() + pub fn filter( + mut self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + let predicate = with_query_bind_step!(self.binder, QueryBindStep::Where, { + let mut scope = self.expr_scope(); + build(&mut scope)?.into_orm_scalar() + }); + let predicate = predicate?; + self.filter_expr(predicate) } -} -impl QueryOperand for FromBuilder { - type Source = Q; - type Model = M; - type Projection = ValueProjection; - type Shape = ValueProjection; - - fn into_query_parts(self) -> (Self::Source, Query) { - self.inner.into_query_parts() + fn join_with( + self, + join_type: JoinType, + alias: Option, + constraint: JoinConstraintInput, + ) -> Result { + let source = match alias { + Some(alias) => QuerySource::model::().with_alias(alias), + None => QuerySource::model::(), + }; + let (right_plan, right_context) = { + let mut right_binder = Binder::new( + self.binder.context.fork_empty(), + self.binder.args, + Some(&self.binder.context), + ); + let right_plan = + bind_orm_source(&mut right_binder, source, Some(join_type), self.arena)?; + (right_plan, right_binder.context) + }; + self.join_plan(right_plan, right_context, join_type, constraint) } -} - -impl QueryOperand for FromBuilder { - type Source = Q; - type Model = M; - type Projection = TupleProjection; - type Shape = TupleProjection; - fn into_query_parts(self) -> (Self::Source, Query) { - self.inner.into_query_parts() + fn join_on( + mut self, + join_type: JoinType, + alias: Option, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + let source = match alias { + Some(alias) => QuerySource::model::().with_alias(alias), + None => QuerySource::model::(), + }; + let (right_plan, right_context) = { + let mut right_binder = Binder::new( + self.binder.context.fork_empty(), + self.binder.args, + Some(&self.binder.context), + ); + let right_plan = + bind_orm_source(&mut right_binder, source, Some(join_type), self.arena)?; + (right_plan, right_binder.context) + }; + self.binder.extend(right_context); + let on = with_query_bind_step!(self.binder, QueryBindStep::From, { + let mut scope = self.expr_scope(); + build(&mut scope)?.into_orm_scalar() + }); + self.plan = self.binder.bind_join_plans( + self.plan, + right_plan, + join_type, + JoinConstraintInput::On(on?), + self.arena, + )?; + Ok(self) } -} - -impl QueryOperand - for FromBuilder> -{ - type Source = Q; - type Model = M; - type Projection = StructProjection; - type Shape = T; - fn into_query_parts(self) -> (Self::Source, Query) { - self.inner.into_query_parts() + pub fn inner_join( + self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + self.join_on::(JoinType::Inner, None, build) } -} - -impl QueryOperand for SetQueryBuilder { - type Source = Q; - type Model = M; - type Projection = ModelProjection; - type Shape = M; - fn into_query_parts(self) -> (Self::Source, Query) { - (self.source, self.query) + pub fn inner_join_as( + self, + alias: impl Into, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + self.join_on::(JoinType::Inner, Some(alias.into()), build) } -} -impl QueryOperand for SetQueryBuilder { - type Source = Q; - type Model = M; - type Projection = ValueProjection; - type Shape = ValueProjection; - - fn into_query_parts(self) -> (Self::Source, Query) { - (self.source, self.query) + pub fn left_join( + self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + self.join_on::(JoinType::LeftOuter, None, build) } -} - -impl QueryOperand for SetQueryBuilder { - type Source = Q; - type Model = M; - type Projection = TupleProjection; - type Shape = TupleProjection; - fn into_query_parts(self) -> (Self::Source, Query) { - (self.source, self.query) + pub fn left_join_as( + self, + alias: impl Into, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + self.join_on::(JoinType::LeftOuter, Some(alias.into()), build) } -} -impl QueryOperand - for SetQueryBuilder> -{ - type Source = Q; - type Model = M; - type Projection = StructProjection; - type Shape = T; - - fn into_query_parts(self) -> (Self::Source, Query) { - (self.source, self.query) + pub fn right_join( + self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + self.join_on::(JoinType::RightOuter, None, build) } -} - -fn ident(value: impl Into) -> Ident { - Ident::new(value) -} - -fn object_name(value: &str) -> ObjectName { - value.split('.').map(ident).collect::>().into() -} - -fn qualified_column_value(relation: &str, column: &str) -> QueryValue { - QueryValue::from_expr(Expr::CompoundIdentifier(vec![ - ident(relation), - ident(column), - ])) -} - -fn nested_expr(expr: Expr) -> Expr { - Expr::Nested(Box::new(expr)) -} - -fn number_expr(value: impl ToString) -> Expr { - Expr::Value(Value::Number(value.to_string(), false).with_empty_span()) -} -fn string_expr(value: impl Into) -> Expr { - Expr::Value(Value::SingleQuotedString(value.into()).with_empty_span()) -} - -fn describe_text_value(value: Option) -> String { - match value { - Some(DataValue::Utf8 { value, .. }) => value, - Some(other) => other.to_string(), - None => String::new(), + pub fn full_join( + self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + self.join_on::(JoinType::Full, None, build) } -} - -fn placeholder_expr(value: &str) -> Expr { - Expr::Value(Value::Placeholder(value.to_string()).with_empty_span()) -} - -fn typed_string_expr(data_type: DataType, value: impl Into) -> Expr { - Expr::TypedString(sqlparser::ast::TypedString { - data_type, - value: Value::SingleQuotedString(value.into()).with_empty_span(), - uses_odbc_syntax: false, - }) -} -fn column_option(option: ColumnOption) -> ColumnOptionDef { - ColumnOptionDef { name: None, option } -} - -fn table_factor(table_name: &str) -> TableFactor { - TableFactor::Table { - name: object_name(table_name), - alias: None, - args: None, - with_hints: vec![], - version: None, - with_ordinality: false, - partitions: vec![], - json_path: None, - sample: None, - index_hints: vec![], + pub fn cross_join(self) -> Result { + self.join_with::(JoinType::Cross, None, JoinConstraintInput::None) } -} -fn source_table_factor(source: &QuerySource) -> TableFactor { - TableFactor::Table { - name: object_name(&source.table_name), - alias: source.alias.as_ref().map(|alias| TableAlias { - explicit: true, - name: ident(alias), - columns: vec![], - }), - args: None, - with_hints: vec![], - version: None, - with_ordinality: false, - partitions: vec![], - json_path: None, - sample: None, - index_hints: vec![], + fn join_using( + self, + join_type: JoinType, + columns: impl IntoIterator>, + ) -> Result { + self.join_with::( + join_type, + None, + JoinConstraintInput::Using(columns.into_iter().map(Into::into).collect()), + ) } -} -fn table_with_joins(table_name: &str) -> TableWithJoins { - TableWithJoins { - relation: table_factor(table_name), - joins: vec![], + pub fn inner_join_using( + self, + columns: impl IntoIterator>, + ) -> Result { + self.join_using::(JoinType::Inner, columns) } -} -fn source_table_with_joins(source: &QuerySource, joins: Vec) -> TableWithJoins { - TableWithJoins { - relation: source_table_factor(source), - joins: joins.into_iter().map(JoinSpec::into_ast).collect(), + pub fn left_join_using( + self, + columns: impl IntoIterator>, + ) -> Result { + self.join_using::(JoinType::LeftOuter, columns) } -} - -fn validate_mutation_state( - state: &BuilderState, - kind: MutationKind, -) -> Result<(), DatabaseError> { - let operation = kind.as_str(); - - if !state.joins.is_empty() { - return Err(DatabaseError::UnsupportedStmt(format!( - "ORM {operation} builder does not support joins" - ))); - } - if state.distinct { - return Err(DatabaseError::UnsupportedStmt(format!( - "ORM {operation} builder does not support distinct" - ))); - } - if !state.group_bys.is_empty() { - return Err(DatabaseError::UnsupportedStmt(format!( - "ORM {operation} builder does not support group by" - ))); - } - if state.having.is_some() { - return Err(DatabaseError::UnsupportedStmt(format!( - "ORM {operation} builder does not support having" - ))); - } - if !state.order_bys.is_empty() { - return Err(DatabaseError::UnsupportedStmt(format!( - "ORM {operation} builder does not support order by" - ))); - } - if state.limit.is_some() { - return Err(DatabaseError::UnsupportedStmt(format!( - "ORM {operation} builder does not support limit" - ))); - } - if state.offset.is_some() { - return Err(DatabaseError::UnsupportedStmt(format!( - "ORM {operation} builder does not support offset" - ))); - } - - Ok(()) -} - -fn select_projection(fields: &[OrmField], relation: &str) -> Vec { - fields - .iter() - .map(|field| { - SelectItem::UnnamedExpr(qualified_column_value(relation, field.column).into_expr()) - }) - .collect() -} - -fn model_insert_columns() -> Vec { - M::fields() - .iter() - .map(|field| ident(field.column)) - .collect() -} -#[allow(clippy::too_many_arguments)] -fn select_query( - source: &QuerySource, - joins: Vec, - projection: Vec, - distinct: bool, - filter: Option, - group_bys: Vec, - having: Option, - order_bys: Vec, - limit: Option, - offset: Option, -) -> Query { - Query { - with: None, - body: Box::new(SetExpr::Select(Box::new(Select { - select_token: AttachedToken::empty(), - optimizer_hint: None, - distinct: distinct.then_some(Distinct::Distinct), - select_modifiers: None, - top: None, - top_before_distinct: false, - projection, - exclude: None, - into: None, - from: vec![source_table_with_joins(source, joins)], - lateral_views: vec![], - prewhere: None, - selection: filter.map(QueryExpr::into_expr), - connect_by: vec![], - group_by: GroupByExpr::Expressions( - group_bys.into_iter().map(QueryValue::into_expr).collect(), - vec![], - ), - cluster_by: vec![], - distribute_by: vec![], - sort_by: vec![], - having: having.map(QueryExpr::into_expr), - named_window: vec![], - qualify: None, - window_before_qualify: false, - value_table_mode: None, - flavor: SelectFlavor::Standard, - }))), - order_by: (!order_bys.is_empty()).then(|| OrderBy { - kind: OrderByKind::Expressions(order_bys.into_iter().map(SortExpr::into_ast).collect()), - interpolate: None, - }), - limit_clause: if limit.is_some() || offset.is_some() { - Some(LimitClause::LimitOffset { - limit: limit.map(number_expr), - offset: offset.map(|offset| Offset { - value: number_expr(offset), - rows: OffsetRows::None, - }), - limit_by: vec![], - }) - } else { - None - }, - fetch: None, - locks: vec![], - for_clause: None, - settings: None, - format_clause: None, - pipe_operators: vec![], + pub fn right_join_using( + self, + columns: impl IntoIterator>, + ) -> Result { + self.join_using::(JoinType::RightOuter, columns) } -} -fn set_operation_query( - left: Query, - right: Query, - op: SetOperator, - set_quantifier: SetQuantifier, -) -> Query { - Query { - with: None, - body: Box::new(SetExpr::SetOperation { - op, - set_quantifier, - left: Box::new(SetExpr::Query(Box::new(left))), - right: Box::new(SetExpr::Query(Box::new(right))), - }), - order_by: None, - limit_clause: None, - fetch: None, - locks: vec![], - for_clause: None, - settings: None, - format_clause: None, - pipe_operators: vec![], + pub fn full_join_using( + self, + columns: impl IntoIterator>, + ) -> Result { + self.join_using::(JoinType::Full, columns) } -} -fn set_query_quantifier(query: &mut Query, set_quantifier: SetQuantifier) { - if let SetExpr::SetOperation { - set_quantifier: current, - .. - } = query.body.as_mut() + pub fn project_model( + mut self, + ) -> Result, DatabaseError> { - *current = set_quantifier; + let relation = self.model_relation_name()?; + let mut select_list = Vec::with_capacity(M::fields().len()); + with_query_bind_step!(self.binder, QueryBindStep::Project, { + let scope = self.expr_scope(); + for field in M::fields() { + select_list.push( + scope + .qualified_column( + &relation, + Field::::new(M::table_name(), field.column), + )? + .into_orm_scalar(), + ); + } + })?; + Ok(self.select_list(select_list)) } -} -fn set_query_order_value(value: QueryValue) -> QueryValue { - match value.into_expr() { - Expr::CompoundIdentifier(mut parts) => QueryValue::from_expr(Expr::Identifier( - parts.pop().expect("compound identifier must not be empty"), - )), - expr => QueryValue::from_expr(expr), + pub fn project( + mut self, + ) -> Result, DatabaseError> + { + let relation = self.model_relation_name()?; + let projection = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + P::bind_projection(&mut scope, &relation)? + }); + Ok(self.select_list(projection?)) } -} -fn query_push_order(query: &mut Query, order: SortExpr) { - let order_expr = order.into_ast(); - match query.order_by.as_mut() { - Some(order_by) => match &mut order_by.kind { - OrderByKind::Expressions(exprs) => exprs.push(order_expr), - OrderByKind::All(_) => { - order_by.kind = OrderByKind::Expressions(vec![order_expr]); - } - }, - None => { - query.order_by = Some(OrderBy { - kind: OrderByKind::Expressions(vec![order_expr]), - interpolate: None, - }); - } + pub fn project_value( + mut self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result, DatabaseError> + where + E: IntoOrmScalarExpression, + { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + build(&mut scope)?.into_orm_scalar() + }); + Ok(self.select_list(vec![expr?])) } -} -fn query_set_last_order_nulls(query: &mut Query, nulls_first: bool) { - if let Some(order_by) = query.order_by.as_mut() { - match &mut order_by.kind { - OrderByKind::Expressions(exprs) => { - if let Some(last) = exprs.last_mut() { - last.options.nulls_first = Some(nulls_first); - } - } - OrderByKind::All(_) => {} - } + pub fn project_tuple( + mut self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, DatabaseError>, + ) -> Result, DatabaseError> + where + E: IntoOrmScalarExpression, + { + let exprs = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + build(&mut scope)? + }); + Ok(self.select_list( + exprs? + .into_iter() + .map(IntoOrmScalarExpression::into_orm_scalar) + .collect(), + )) } -} - -fn query_set_limit(query: &mut Query, limit: usize) { - let offset = query_current_offset(query); - query.limit_clause = Some(LimitClause::LimitOffset { - limit: Some(number_expr(limit)), - offset, - limit_by: vec![], - }); -} - -fn query_set_offset(query: &mut Query, offset: usize) { - let limit = query_current_limit(query); - query.limit_clause = Some(LimitClause::LimitOffset { - limit, - offset: Some(Offset { - value: number_expr(offset), - rows: OffsetRows::None, - }), - limit_by: vec![], - }); -} -fn query_current_limit(query: &Query) -> Option { - match &query.limit_clause { - Some(LimitClause::LimitOffset { limit, .. }) => limit.clone(), - Some(LimitClause::OffsetCommaLimit { limit, .. }) => Some(limit.clone()), - None => None, + pub fn project_scalar( + mut self, + build: impl BindOrmScalar<'bind, 'parent, 'arena, T, A>, + ) -> Result, DatabaseError> + { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + build.bind_scalar(&mut scope)? + }); + Ok(self.select_list(vec![expr?])) } -} -fn query_current_offset(query: &Query) -> Option { - match &query.limit_clause { - Some(LimitClause::LimitOffset { offset, .. }) => offset.clone(), - Some(LimitClause::OffsetCommaLimit { offset, .. }) => Some(Offset { - value: offset.clone(), - rows: OffsetRows::None, - }), - None => None, + pub fn project_scalars( + mut self, + build: impl BindOrmScalarList<'bind, 'parent, 'arena, T, A>, + ) -> Result, DatabaseError> + { + let exprs = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + build.bind_scalar_list(&mut scope)? + }); + Ok(self.select_list(exprs?)) } -} - -fn execute_query(source: Q, query: Query) -> Result { - source.execute_statement(&Statement::Query(Box::new(query)), &[]) -} - -fn execute_insert_query( - source: Q, - statement: Statement, -) -> Result<(), DatabaseError> { - source.execute_statement(&statement, &[])?.done() -} - -fn query_explain(source: Q, query: Query) -> Result { - let mut iter = source.execute_statement(&orm_explain_query_statement(query), &[])?; - let plan = match iter.next().transpose()? { - Some(tuple) => extract_value_from_tuple::(tuple)?, - None => { - return Err(DatabaseError::InvalidValue( - "EXPLAIN returned no plan rows".to_string(), - )) - } - }; - iter.done()?; - Ok(plan) -} - -fn orm_update_builder_statement( - source: &QuerySource, - filter: Option, - assignments: Vec, -) -> Statement { - Statement::Update(Update { - update_token: AttachedToken::empty(), - optimizer_hint: None, - table: source_table_with_joins(source, vec![]), - assignments, - from: None, - selection: filter.map(QueryExpr::into_expr), - returning: None, - or: None, - limit: None, - }) -} - -fn orm_insert_query_statement( - table_name: &str, - columns: Vec, - query: Query, - overwrite: bool, -) -> Statement { - Statement::Insert(Insert { - insert_token: AttachedToken::empty(), - optimizer_hint: None, - or: None, - ignore: false, - into: true, - table: TableObject::TableName(object_name(table_name)), - table_alias: None, - columns, - overwrite, - source: Some(Box::new(query)), - assignments: vec![], - partitioned: None, - after_columns: vec![], - has_table_keyword: false, - on: None, - returning: None, - replace_into: false, - priority: None, - insert_alias: None, - settings: None, - format_clause: None, - }) -} -fn empty_show_options() -> ShowStatementOptions { - ShowStatementOptions { - show_in: None, - starts_with: None, - limit: None, - limit_from: None, - filter_position: None, + pub fn group_by( + self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result, DatabaseError> + where + E: IntoOrmScalarExpression, + { + self.project_model()?.group_by(build) } -} -fn orm_show_tables_statement() -> Statement { - Statement::ShowTables { - terse: false, - history: false, - extended: false, - full: false, - external: false, - show_options: empty_show_options(), + pub fn having( + self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result, DatabaseError> + where + E: IntoOrmScalarExpression, + { + self.project_model()?.having(build) } -} -fn orm_show_views_statement() -> Statement { - Statement::ShowViews { - terse: false, - materialized: false, - show_options: empty_show_options(), + pub fn group_by_scalar( + self, + build: impl BindOrmScalar<'bind, 'parent, 'arena, T, A>, + ) -> Result, DatabaseError> + { + self.project_model()?.group_by_scalar(build) } -} -fn orm_describe_statement(table_name: &str) -> Statement { - Statement::ExplainTable { - describe_alias: DescribeAlias::Describe, - hive_format: None, - has_table_keyword: false, - table_name: object_name(table_name), + pub fn having_scalar( + self, + build: impl BindOrmScalar<'bind, 'parent, 'arena, T, A>, + ) -> Result, DatabaseError> + { + self.project_model()?.having_scalar(build) } -} -fn orm_explain_query_statement(query: Query) -> Statement { - Statement::Explain { - describe_alias: DescribeAlias::Explain, - analyze: false, - verbose: false, - query_plan: false, - estimate: false, - statement: Box::new(Statement::Query(Box::new(query))), - format: None, - options: None, + pub fn order_by( + self, + build: impl BindOrmSort<'bind, 'parent, 'arena, T, A>, + ) -> Result, DatabaseError> + { + self.project_model()?.order_by(build) } -} -fn orm_delete_builder_statement(source: &QuerySource, filter: Option) -> Statement { - Statement::Delete(Delete { - delete_token: AttachedToken::empty(), - optimizer_hint: None, - tables: vec![], - from: FromTable::WithFromKeyword(vec![source_table_with_joins(source, vec![])]), - using: None, - selection: filter.map(QueryExpr::into_expr), - returning: None, - order_by: vec![], - limit: None, - }) -} + pub fn order_by_expr( + self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result, DatabaseError> + { + self.project_model()?.order_by_expr(build) + } -fn query_exists(source: Q, mut query: Query) -> Result { - query_set_limit(&mut query, 1); - let mut iter = execute_query(source, query)?; - Ok(iter.next().transpose()?.is_some()) -} + pub fn count(mut self) -> Result { + let count = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let scope = self.expr_scope(); + let count = scope.count_all()?; + scope.alias(count, "count").into_orm_scalar() + }); + self.select_list(vec![count?]).count() + } -fn query_count(source: Q, query: Query) -> Result { - let mut iter = execute_query(source, query)?; - let mut count = 0usize; - while iter.next().transpose()?.is_some() { - count += 1; + pub fn exists(self) -> Result { + self.binder.bind_limit_values(self.plan, None, Some(1)) } - iter.done()?; - Ok(count) -} -fn values_query(values: Vec) -> Query { - Query { - with: None, - body: Box::new(SetExpr::Values(Values { - explicit_row: false, - value_keyword: false, - rows: vec![values], - })), - order_by: None, - limit_clause: None, - fetch: None, - locks: vec![], - for_clause: None, - settings: None, - format_clause: None, - pipe_operators: vec![], + pub fn delete(self) -> Result { + let table_name = self.model_table_name()?; + let primary_keys = self + .binder + .context + .table(table_name.clone())? + .ok_or(DatabaseError::TableNotFound)? + .primary_keys() + .iter() + .map(|(_, column)| *column) + .collect(); + self.binder.with_pk(table_name.clone()); + self.binder.bind_delete(table_name, primary_keys, self.plan) } -} -fn parse_expr_fragment(value: &str) -> Result { - let dialect = PostgreSqlDialect {}; - let mut parser = Parser::new(&dialect).try_with_sql(value)?; - parser.parse_expr().map_err(Into::into) -} + pub fn update( + self, + build: impl FnOnce( + &mut UpdateBindScope<'scope_ctx, 'bind, 'parent, 'arena, T, A>, + ) -> Result<(), DatabaseError>, + ) -> Result { + let table_name = self.model_table_name()?; + let source_name = self.model_relation_name()?; + self.binder.context.allow_default = true; + self.binder.with_pk(table_name.clone()); + let mut scope = UpdateBindScope { + binder: self.binder, + arena: self.arena, + source_name, + value_exprs: Vec::new(), + }; + build(&mut scope)?; + scope.finish(table_name, self.plan) + } -fn parse_data_type_fragment(value: &str) -> Result { - let dialect = PostgreSqlDialect {}; - let mut parser = Parser::new(&dialect).try_with_sql(value)?; - parser.parse_data_type().map_err(Into::into) + pub fn finish(self) -> Result { + self.project_model()?.finish() + } } -fn data_value_to_ast_expr(value: &DataValue) -> Expr { - match value { - DataValue::Null => Expr::Value(Value::Null.with_empty_span()), - DataValue::Boolean(value) => Expr::Value(Value::Boolean(*value).with_empty_span()), - DataValue::Float32(value) => number_expr(value), - DataValue::Float64(value) => number_expr(value), - DataValue::Int8(value) => number_expr(value), - DataValue::Int16(value) => number_expr(value), - DataValue::Int32(value) => number_expr(value), - DataValue::Int64(value) => number_expr(value), - DataValue::UInt8(value) => number_expr(value), - DataValue::UInt16(value) => number_expr(value), - DataValue::UInt32(value) => number_expr(value), - DataValue::UInt64(value) => number_expr(value), - DataValue::Utf8 { value, .. } => string_expr(value), - DataValue::Date32(_) => typed_string_expr(DataType::Date, value.to_string()), - DataValue::Date64(_) => typed_string_expr(DataType::Datetime(None), value.to_string()), - DataValue::Time32(..) => { - typed_string_expr(DataType::Time(None, TimezoneInfo::None), value.to_string()) - } - DataValue::Time64(_, _, zone) => typed_string_expr( - DataType::Timestamp( - None, - if *zone { - TimezoneInfo::WithTimeZone - } else { - TimezoneInfo::None - }, - ), - value.to_string(), - ), - DataValue::Decimal(value) => number_expr(value), - DataValue::Tuple(values, ..) => { - Expr::Tuple(values.iter().map(data_value_to_ast_expr).collect()) +impl<'scope_ctx, 'bind, 'parent, 'arena, T, A, M> + BindPlanSelectList<'scope_ctx, 'bind, 'parent, 'arena, T, A, M> +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, + M: Model, +{ + fn expr_scope<'scope>(&'scope mut self) -> ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A> { + ExprBindScope { + binder: self.binder, + arena: self.arena, } } -} -#[doc(hidden)] -pub fn orm_select_statement(table_name: &str, fields: &[OrmField]) -> Statement { - Statement::Query(Box::new(select_query( - &QuerySource { - table_name: table_name.to_string(), - alias: None, - }, - vec![], - select_projection(fields, table_name), - false, - None, - vec![], - None, - vec![], - None, - None, - ))) -} - -#[doc(hidden)] -pub fn orm_insert_statement(table_name: &str, fields: &[OrmField]) -> Statement { - orm_insert_values_statement(table_name, fields, false) -} - -#[doc(hidden)] -pub fn orm_overwrite_statement(table_name: &str, fields: &[OrmField]) -> Statement { - orm_insert_values_statement(table_name, fields, true) -} + pub fn project_value( + mut self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + build(&mut scope)?.into_orm_scalar() + }); + Ok(self.set_select_list(vec![expr?])) + } -fn orm_insert_values_statement( - table_name: &str, - fields: &[OrmField], - overwrite: bool, -) -> Statement { - Statement::Insert(Insert { - insert_token: AttachedToken::empty(), - optimizer_hint: None, - or: None, - ignore: false, - into: true, - table: TableObject::TableName(object_name(table_name)), - table_alias: None, - columns: fields.iter().map(|field| ident(field.column)).collect(), - overwrite, - source: Some(Box::new(values_query( - fields - .iter() - .map(|field| placeholder_expr(field.placeholder)) + pub fn project_tuple( + mut self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, DatabaseError>, + ) -> Result + where + E: IntoOrmScalarExpression, + { + let exprs = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + build(&mut scope)? + }); + Ok(self.set_select_list( + exprs? + .into_iter() + .map(IntoOrmScalarExpression::into_orm_scalar) .collect(), - ))), - assignments: vec![], - partitioned: None, - after_columns: vec![], - has_table_keyword: false, - on: None, - returning: None, - replace_into: false, - priority: None, - insert_alias: None, - settings: None, - format_clause: None, - }) -} + )) + } -#[doc(hidden)] -pub fn orm_truncate_statement(table_name: &str) -> Statement { - Statement::Truncate(Truncate { - table_names: vec![TruncateTableTarget { - name: object_name(table_name), - only: false, - has_asterisk: false, - }], - partitions: None, - table: true, - if_exists: false, - identity: None, - cascade: None, - on_cluster: None, - }) -} + pub fn project_scalar( + mut self, + build: impl BindOrmScalar<'bind, 'parent, 'arena, T, A>, + ) -> Result { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + build.bind_scalar(&mut scope)? + }); + Ok(self.set_select_list(vec![expr?])) + } -#[doc(hidden)] -pub fn orm_create_view_statement(view_name: &str, query: Query, or_replace: bool) -> Statement { - Statement::CreateView(CreateView { - or_alter: false, - or_replace, - materialized: false, - secure: false, - name: object_name(view_name), - name_before_not_exists: false, - columns: Vec::::new(), - query: Box::new(query), - options: CreateTableOptions::None, - cluster_by: vec![], - comment: None, - with_no_schema_binding: false, - if_not_exists: false, - temporary: false, - to: None, - params: None, - }) -} + pub fn project_scalars( + mut self, + build: impl BindOrmScalarList<'bind, 'parent, 'arena, T, A>, + ) -> Result { + let exprs = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let mut scope = self.expr_scope(); + build.bind_scalar_list(&mut scope)? + }); + Ok(self.set_select_list(exprs?)) + } -#[doc(hidden)] -pub fn orm_drop_view_statement(view_name: &str, if_exists: bool) -> Statement { - Statement::Drop { - object_type: ObjectType::View, - if_exists, - names: vec![object_name(view_name)], - cascade: false, - restrict: false, - purge: false, - temporary: false, - table: None, + pub fn group_by( + mut self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Agg, { + let mut scope = self.expr_scope(); + build(&mut scope)?.into_orm_scalar() + }); + self.group_by_expr(expr?) } -} -#[doc(hidden)] -pub fn orm_find_statement( - table_name: &str, - fields: &[OrmField], - primary_key: &OrmField, -) -> Statement { - Statement::Query(Box::new(Query { - with: None, - body: Box::new(SetExpr::Select(Box::new(Select { - select_token: AttachedToken::empty(), - optimizer_hint: None, - distinct: None, - select_modifiers: None, - top: None, - top_before_distinct: false, - projection: select_projection(fields, table_name), - exclude: None, - into: None, - from: vec![table_with_joins(table_name)], - lateral_views: vec![], - prewhere: None, - selection: Some(Expr::BinaryOp { - left: Box::new(Expr::Identifier(ident(primary_key.column))), - op: SqlBinaryOperator::Eq, - right: Box::new(placeholder_expr(primary_key.placeholder)), - }), - connect_by: vec![], - group_by: GroupByExpr::Expressions(vec![], vec![]), - cluster_by: vec![], - distribute_by: vec![], - sort_by: vec![], - having: None, - named_window: vec![], - qualify: None, - window_before_qualify: false, - value_table_mode: None, - flavor: SelectFlavor::Standard, - }))), - order_by: None, - limit_clause: None, - fetch: None, - locks: vec![], - for_clause: None, - settings: None, - format_clause: None, - pipe_operators: vec![], - })) -} + pub fn having( + mut self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result + where + E: IntoOrmScalarExpression, + { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Having, { + let mut scope = self.expr_scope(); + build(&mut scope)?.into_orm_scalar() + }); + self.having_expr(expr?) + } -#[doc(hidden)] -pub fn orm_create_table_statement( - table_name: &str, - columns: &[OrmColumn], - if_not_exists: bool, -) -> Result { - Ok(Statement::CreateTable(CreateTable { - or_replace: false, - temporary: false, - external: false, - dynamic: false, - global: None, - if_not_exists, - transient: false, - volatile: false, - iceberg: false, - name: object_name(table_name), - columns: columns - .iter() - .map(OrmColumn::column_def) - .collect::, _>>()?, - constraints: vec![], - hive_distribution: HiveDistributionStyle::NONE, - hive_formats: None, - table_options: Default::default(), - file_format: None, - location: None, - query: None, - without_rowid: false, - like: None, - clone: None, - version: None, - comment: None, - on_commit: None, - on_cluster: None, - primary_key: None, - order_by: None, - partition_by: None, - cluster_by: None, - clustered_by: None, - inherits: None, - partition_of: None, - for_values: None, - strict: false, - copy_grants: false, - enable_schema_evolution: None, - change_tracking: None, - data_retention_time_in_days: None, - max_data_extension_time_in_days: None, - default_ddl_collation: None, - with_aggregation_policy: None, - with_row_access_policy: None, - with_tags: None, - external_volume: None, - base_location: None, - catalog: None, - catalog_sync: None, - storage_serialization_policy: None, - target_lag: None, - warehouse: None, - refresh_mode: None, - initialize: None, - require_user: false, - })) -} + pub fn group_by_scalar( + mut self, + build: impl BindOrmScalar<'bind, 'parent, 'arena, T, A>, + ) -> Result { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Agg, { + let mut scope = self.expr_scope(); + build.bind_scalar(&mut scope)? + }); + self.group_by_expr(expr?) + } -#[doc(hidden)] -pub fn orm_create_index_statement( - table_name: &str, - index_name: &str, - columns: &[&str], - unique: bool, - if_not_exists: bool, -) -> Statement { - Statement::CreateIndex(CreateIndex { - name: Some(object_name(index_name)), - table_name: object_name(table_name), - using: None, - columns: columns.iter().copied().map(IndexColumn::from).collect(), - unique, - concurrently: false, - if_not_exists, - include: vec![], - nulls_distinct: None, - with: vec![], - predicate: None, - index_options: vec![], - alter_options: vec![], - }) -} + pub fn having_scalar( + mut self, + build: impl BindOrmScalar<'bind, 'parent, 'arena, T, A>, + ) -> Result { + let expr = with_query_bind_step!(self.binder, QueryBindStep::Having, { + let mut scope = self.expr_scope(); + build.bind_scalar(&mut scope)? + }); + self.having_expr(expr?) + } -#[doc(hidden)] -pub fn orm_drop_table_statement(table_name: &str, if_exists: bool) -> Statement { - Statement::Drop { - object_type: ObjectType::Table, - if_exists, - names: vec![object_name(table_name)], - cascade: false, - restrict: false, - purge: false, - temporary: false, - table: None, + pub fn order_by( + mut self, + build: impl BindOrmSort<'bind, 'parent, 'arena, T, A>, + ) -> Result { + let sort = with_query_bind_step!(self.binder, QueryBindStep::Sort, { + let mut scope = self.expr_scope(); + build.bind_sort(&mut scope)? + }); + self.sort_field(sort?) } -} -#[doc(hidden)] -pub fn orm_drop_index_statement(table_name: &str, index_name: &str, if_exists: bool) -> Statement { - Statement::Drop { - object_type: ObjectType::Index, - if_exists, - names: vec![object_name(&format!("{table_name}.{index_name}"))], - cascade: false, - restrict: false, - purge: false, - temporary: false, - table: None, + pub fn order_by_expr( + mut self, + build: impl for<'scope> FnOnce( + &'scope mut ExprBindScope<'scope, 'bind, 'parent, 'arena, T, A>, + ) -> Result, + ) -> Result { + let sort = with_query_bind_step!(self.binder, QueryBindStep::Sort, { + let mut scope = self.expr_scope(); + build(&mut scope)? + }); + self.sort_field(sort?) } -} -#[doc(hidden)] -pub fn orm_analyze_statement(table_name: &str) -> Statement { - Statement::Analyze(Analyze { - table_name: Some(object_name(table_name)), - partitions: None, - for_columns: false, - columns: vec![], - cache_metadata: false, - noscan: false, - compute_statistics: false, - has_table_keyword: true, - }) + pub fn count(mut self) -> Result { + let count = with_query_bind_step!(self.binder, QueryBindStep::Project, { + let scope = self.expr_scope(); + let count = scope.count_all()?; + scope.alias(count, "count").into_orm_scalar() + }); + self.set_select_list(vec![count?]) + .aggregate_without_group()? + .finish() + } } -fn orm_count_statement( - source: &QuerySource, - joins: Vec, - filter: Option, -) -> Statement { - Statement::Query(Box::new(select_query( - source, - joins, - vec![SelectItem::UnnamedExpr(Expr::Function(Function { - name: object_name("count"), - uses_odbc_syntax: false, - parameters: FunctionArguments::None, - args: FunctionArguments::List(FunctionArgumentList { - duplicate_treatment: None, - args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard)], - clauses: vec![], - }), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - }))], - false, - filter, - vec![], - None, - vec![], - None, - None, - ))) +#[doc(hidden)] +pub trait Projection: + for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)> +{ + fn bind_projection<'ctx, 'bind, 'parent, 'arena, T, A>( + scope: &mut ExprBindScope<'ctx, 'bind, 'parent, 'arena, T, A>, + relation: &str, + ) -> Result, DatabaseError> + where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>; } -fn orm_alter_table_statement(table_name: &str, operation: AlterTableOperation) -> Statement { - Statement::AlterTable(AlterTable { - name: object_name(table_name), - if_exists: false, - only: false, - operations: vec![operation], - location: None, - on_cluster: None, - table_type: None, - end_token: AttachedToken::empty(), +fn orm_table_alias(source: &QuerySource) -> Option { + source.alias.as_ref().map(|alias| TableAliasInput { + name: alias.as_str().into(), + columns: Vec::new(), }) } -fn orm_alter_column_type_statement( - table_name: &str, - column_name: &str, - ddl_type: &str, -) -> Result { - Ok(orm_alter_table_statement( - table_name, - AlterTableOperation::AlterColumn { - column_name: ident(column_name), - op: AlterColumnOperation::SetDataType { - data_type: parse_data_type_fragment(ddl_type)?, - using: None, - had_set: false, - }, - }, - )) +fn bind_orm_source<'bind, 'parent, 'arena, T, A>( + binder: &mut Binder<'bind, 'parent, T, A>, + source: QuerySource, + join_type: Option, + arena: &mut PlanArena<'arena>, +) -> Result +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + let alias = orm_table_alias(&source); + binder.bind_base_table_ref(join_type, source.table_name.as_str().into(), alias, arena) } -fn orm_alter_column_default_statement( - table_name: &str, +fn bind_orm_target_column<'bind, 'parent, 'arena, T, A>( + binder: &mut Binder<'bind, 'parent, T, A>, + source_name: &str, column_name: &str, - default_expr: Option<&str>, -) -> Result { - Ok(orm_alter_table_statement( - table_name, - AlterTableOperation::AlterColumn { - column_name: ident(column_name), - op: match default_expr { - Some(default_expr) => AlterColumnOperation::SetDefault { - value: parse_expr_fragment(default_expr)?, - }, - None => AlterColumnOperation::DropDefault, - }, - }, - )) + arena: &mut PlanArena<'arena>, +) -> Result +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + match binder.bind_column_ref_by_name(None, column_name, Some(source_name), arena)? { + ScalarExpression::ColumnRef { column, .. } => Ok(column), + _ => Err(DatabaseError::invalid_column(column_name.to_string())), + } } -fn orm_alter_column_nullability_statement( +fn bind_orm_insert_plan<'bind, 'parent, 'arena, T, A>( + binder: &mut Binder<'bind, 'parent, T, A>, table_name: &str, - column_name: &str, - nullable: bool, -) -> Statement { - orm_alter_table_statement( - table_name, - AlterTableOperation::AlterColumn { - column_name: ident(column_name), - op: if nullable { - AlterColumnOperation::DropNotNull - } else { - AlterColumnOperation::SetNotNull - }, - }, - ) -} + columns: Vec, + mut input_plan: LogicalPlan, + overwrite: bool, + arena: &mut PlanArena<'arena>, +) -> Result +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, +{ + let table_name: TableName = table_name.into(); + let input_schema = input_plan.output_schema(arena).clone(); + let input_len = input_schema.len(); + + let projection = { + let source = binder + .context + .source(&table_name)? + .ok_or(DatabaseError::TableNotFound)?; + + if columns.is_empty() { + let table_schema = source.schema(); + if input_len > table_schema.len() { + return Err(DatabaseError::ValuesLenMismatch( + table_schema.len(), + input_len, + )); + } + table_schema[..input_len] + .iter() + .copied() + .enumerate() + .map(|(position, target_column)| ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr( + input_schema[position], + position, + )), + alias: AliasType::Name(arena.column(target_column).name().to_string()), + }) + .collect::>() + } else { + if input_len != columns.len() { + return Err(DatabaseError::ValuesLenMismatch(columns.len(), input_len)); + } + let mut projection = Vec::with_capacity(columns.len()); + for (position, column_name) in columns.into_iter().enumerate() { + let column = source + .column(&column_name, arena) + .ok_or_else(|| DatabaseError::column_not_found(column_name.clone()))?; + projection.push(ScalarExpression::Alias { + expr: Box::new(ScalarExpression::column_expr( + input_schema[position], + position, + )), + alias: AliasType::Name(arena.column(column).name().to_string()), + }); + } + projection + } + }; + input_plan = binder.bind_project(input_plan, projection, arena)?; -fn orm_rename_column_statement(table_name: &str, old_name: &str, new_name: &str) -> Statement { - orm_alter_table_statement( - table_name, - AlterTableOperation::RenameColumn { - old_column_name: ident(old_name), - new_column_name: ident(new_name), - }, - ) + binder.bind_insert_query(table_name, input_plan, overwrite) } -fn orm_drop_column_statement(table_name: &str, column_name: &str) -> Statement { - orm_alter_table_statement( - table_name, - AlterTableOperation::DropColumn { - has_column_keyword: true, - column_names: vec![ident(column_name)], - if_exists: false, - drop_behavior: None, - }, - ) +fn bind_orm_insert_model<'bind, 'parent, 'arena, T, A, M>( + binder: &mut Binder<'bind, 'parent, T, A>, + params: Vec<(&'static str, DataValue)>, + arena: &mut PlanArena<'arena>, +) -> Result +where + T: Transaction, + A: AsRef<[(&'static str, DataValue)]>, + M: Model, +{ + let table_name: TableName = M::table_name().into(); + let source = binder + .context + .source_and_bind(table_name.clone(), None, None, false)? + .ok_or(DatabaseError::TableNotFound)?; + let params = params.into_iter().collect::>(); + let mut schema_ref = Vec::with_capacity(M::fields().len()); + let mut row = Vec::with_capacity(M::fields().len()); + + for field in M::fields() { + let column = source + .column(field.column, arena) + .ok_or_else(|| DatabaseError::column_not_found(field.column.to_string()))?; + let column_catalog = arena.column(column); + let value = params + .get(field.placeholder) + .ok_or_else(|| DatabaseError::parameter_not_found(field.placeholder))? + .clone() + .cast(column_catalog.datatype())?; + value.check_len(column_catalog.datatype())?; + if matches!(value, DataValue::Null) && !column_catalog.nullable() { + return Err(DatabaseError::not_null_column( + column_catalog.name().to_string(), + )); + } + schema_ref.push(column); + row.push(value); + } + + binder.bind_insert_values(table_name, schema_ref, vec![row], false, true) } -fn orm_add_column_statement( - table_name: &str, - column: &OrmColumn, -) -> Result { - Ok(orm_alter_table_statement( - table_name, - AlterTableOperation::AddColumn { - column_keyword: true, - if_not_exists: false, - column_def: column.column_def()?, - column_position: None, - }, - )) +fn describe_text_value(value: Option) -> String { + match value { + Some(DataValue::Utf8 { value, .. }) => value, + Some(other) => other.to_string(), + None => String::new(), + } } /// Trait implemented by ORM models. /// /// In normal usage you should derive this trait with `#[derive(Model)]` rather -/// than implementing it by hand. The derive macro generates tuple mapping, -/// cached model/DDL statements and model metadata. -pub trait Model: Sized + for<'a> From<(&'a SchemaRef, Tuple)> { +/// than implementing it by hand. The derive macro generates tuple mapping and +/// model metadata. +pub trait Model: + Sized + for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)> +{ /// Rust type used as the model primary key. /// /// This associated type lets APIs such as @@ -4959,11 +2466,16 @@ pub trait Model: Sized + for<'a> From<(&'a SchemaRef, Tuple)> { /// Returns metadata for every persisted field on the model. fn fields() -> &'static [OrmField]; - /// Returns persisted column definitions for the model. + /// Returns persisted column catalogs for the model. /// /// `#[derive(Model)]` generates this automatically. Manual implementations /// can override it to opt into [`Database::migrate`](crate::orm::Database::migrate). - fn columns() -> &'static [OrmColumn] { + fn columns() -> &'static [ColumnCatalog] { + &[] + } + + /// Returns secondary indexes declared by the model. + fn indexes() -> &'static [(&'static str, &'static [&'static str], bool)] { &[] } @@ -4973,44 +2485,6 @@ pub trait Model: Sized + for<'a> From<(&'a SchemaRef, Tuple)> { /// Returns a reference to the current primary-key value. fn primary_key(&self) -> &Self::PrimaryKey; - /// Returns the cached `SELECT` statement used by [`Database::fetch`](crate::orm::Database::fetch). - fn select_statement() -> &'static Statement; - - /// Returns the cached `INSERT` statement for the model. - fn insert_statement() -> &'static Statement; - - /// Returns the cached `SELECT .. WHERE primary_key = ...` statement. - fn find_statement() -> &'static Statement; - - /// Returns the cached `CREATE TABLE` statement for the model. - fn create_table_statement() -> &'static Statement; - - /// Returns the cached `CREATE TABLE IF NOT EXISTS` statement for the model. - fn create_table_if_not_exists_statement() -> &'static Statement; - - /// Returns cached `CREATE INDEX` statements declared by the model. - /// - /// `#[derive(Model)]` generates these from fields annotated with - /// `#[model(index)]`. Manual implementations can override this to provide - /// custom secondary indexes. - fn create_index_statements() -> &'static [Statement] { - &[] - } - - /// Returns cached `CREATE INDEX IF NOT EXISTS` statements declared by the model. - fn create_index_if_not_exists_statements() -> &'static [Statement] { - &[] - } - - /// Returns the cached `DROP TABLE` statement for the model. - fn drop_table_statement() -> &'static Statement; - - /// Returns the cached `DROP TABLE IF EXISTS` statement for the model. - fn drop_table_if_exists_statement() -> &'static Statement; - - /// Returns the cached `ANALYZE TABLE` statement for the model. - fn analyze_statement() -> &'static Statement; - /// Returns metadata for the primary-key field. fn primary_key_field() -> &'static OrmField { Self::fields() @@ -5023,16 +2497,15 @@ pub trait Model: Sized + for<'a> From<(&'a SchemaRef, Tuple)> { /// Conversion trait from [`DataValue`] into Rust values for ORM mapping. /// /// This trait is mainly intended for framework internals and derive-generated -/// code, but it also powers scalar projections such as [`FromBuilder::project_value`]. +/// code, but it also powers scalar projections decoded from binder-backed ORM plans. /// /// Built-in scalar types already implement this trait, so most users only need /// to pick the target type when decoding: /// /// ```rust,ignore /// let ids = database -/// .from::() -/// .project_value(User::id()) -/// .fetch::()?; +/// .bind(|ctx| ctx.from::()?.project_scalar(User::id()))? +/// .project_value::(); /// # Ok::<(), kite_sql::errors::DatabaseError>(()) /// ``` pub trait FromDataValue: Sized { @@ -5049,9 +2522,8 @@ pub trait FromDataValue: Sized { /// /// ```rust,ignore /// let rows = database -/// .from::() -/// .project_tuple((User::id(), User::name())) -/// .fetch::<(i32, String)>()?; +/// .bind(|ctx| ctx.from::()?.project_scalars((User::id(), User::name())))? +/// .project_tuple::<(i32, String)>(); /// # Ok::<(), kite_sql::errors::DatabaseError>(()) /// ``` pub trait FromQueryTuple: Sized { @@ -5061,13 +2533,12 @@ pub trait FromQueryTuple: Sized { /// Typed adapter over a [`ResultIter`] that yields projected values instead of raw tuples. /// -/// This is returned by [`FromBuilder::project_value`] followed by `fetch::()`. +/// This adapts a raw ORM result iterator into scalar projected values. /// /// ```rust,ignore /// let mut ids = database -/// .from::() -/// .project_value(User::id()) -/// .fetch::()?; +/// .bind(|ctx| ctx.from::()?.project_scalar(User::id()))? +/// .project_value::(); /// /// let first = ids.next().transpose()?; /// ids.done()?; @@ -5079,6 +2550,19 @@ pub struct ProjectValueIter { _marker: PhantomData, } +/// Convenience adapters for raw result iterators produced by binder-backed ORM plans. +pub trait OrmQueryResultExt: ResultIter + Sized { + fn project_value(self) -> ProjectValueIter { + ProjectValueIter::new(self) + } + + fn project_tuple(self) -> ProjectTupleIter { + ProjectTupleIter::new(self) + } +} + +impl OrmQueryResultExt for I {} + impl ProjectValueIter where I: ResultIter, @@ -5102,13 +2586,12 @@ where /// Typed adapter over a [`ResultIter`] that yields projected tuples. /// -/// This is returned by [`FromBuilder::project_tuple`] followed by `fetch::()`. +/// This adapts a raw ORM result iterator into tuple projected rows. /// /// ```rust,ignore /// let mut rows = database -/// .from::() -/// .project_tuple((User::id(), User::name())) -/// .fetch::<(i32, String)>()?; +/// .bind(|ctx| ctx.from::()?.project_scalars((User::id(), User::name())))? +/// .project_tuple::<(i32, String)>(); /// /// let first = rows.next().transpose()?; /// rows.done()?; @@ -5190,8 +2673,8 @@ pub trait ToDataValue { /// This trait only affects ORM-generated DDL. Query decoding still goes through /// [`FromDataValue`], and bound parameters still go through [`ToDataValue`]. pub trait ModelColumnType { - /// Returns the SQL type name used in ORM-generated DDL. - fn ddl_type() -> String; + /// Returns the core logical type used in ORM-generated DDL. + fn logical_type() -> LogicalType; /// Whether this field type maps to a nullable SQL column. fn nullable() -> bool { @@ -5217,14 +2700,11 @@ pub trait DecimalType {} #[doc(hidden)] pub fn try_get( tuple: &mut Tuple, - schema: &SchemaRef, + schema: &SchemaView<'_, '_>, field_name: &str, ) -> Option { let ty = T::logical_type()?; - let (idx, _) = schema - .iter() - .enumerate() - .find(|(_, col)| col.name() == field_name)?; + let idx = schema.position(field_name)?; let value = std::mem::replace(&mut tuple.values[idx], DataValue::Null) .cast(&ty) @@ -5270,46 +2750,79 @@ impl_from_data_value_by_method!(u32, u32); impl_from_data_value_by_method!(u64, u64); impl_from_data_value_by_method!(f32, float); impl_from_data_value_by_method!(f64, double); -impl_from_data_value_by_method!(NaiveDate, date); -impl_from_data_value_by_method!(NaiveDateTime, datetime); -impl_from_data_value_by_method!(NaiveTime, time); +#[cfg(feature = "decimal")] impl_from_data_value_by_method!(Decimal, decimal); -impl_to_data_value_by_clone!(bool, i8, i16, i32, i64, u8, u16, u32, u64, f32, f64, Decimal, String); +impl_to_data_value_by_clone!(bool, i8, i16, i32, i64, u8, u16, u32, u64, f32, f64, String); +#[cfg(feature = "decimal")] +impl_to_data_value_by_clone!(Decimal); macro_rules! impl_model_column_type { - ($sql:expr; $($ty:ty),+ $(,)?) => { + ($logical_type:expr; $($ty:ty),+ $(,)?) => { $( impl ModelColumnType for $ty { - fn ddl_type() -> String { - $sql.to_string() + fn logical_type() -> LogicalType { + $logical_type } } )+ }; } -impl_model_column_type!("boolean"; bool); -impl_model_column_type!("tinyint"; i8); -impl_model_column_type!("smallint"; i16); -impl_model_column_type!("int"; i32); -impl_model_column_type!("bigint"; i64); -impl_model_column_type!("utinyint"; u8); -impl_model_column_type!("usmallint"; u16); -impl_model_column_type!("unsigned integer"; u32); -impl_model_column_type!("ubigint"; u64); -impl_model_column_type!("float"; f32); -impl_model_column_type!("double"; f64); -impl_model_column_type!("date"; NaiveDate); -impl_model_column_type!("datetime"; NaiveDateTime); -impl_model_column_type!("time"; NaiveTime); -impl_model_column_type!("decimal"; Decimal); -impl_model_column_type!("varchar"; String, Arc); +impl_model_column_type!(LogicalType::Boolean; bool); +impl_model_column_type!(LogicalType::Tinyint; i8); +impl_model_column_type!(LogicalType::Smallint; i16); +impl_model_column_type!(LogicalType::Integer; i32); +impl_model_column_type!(LogicalType::Bigint; i64); +impl_model_column_type!(LogicalType::UTinyint; u8); +impl_model_column_type!(LogicalType::USmallint; u16); +impl_model_column_type!(LogicalType::UInteger; u32); +impl_model_column_type!(LogicalType::UBigint; u64); +impl_model_column_type!(LogicalType::Float; f32); +impl_model_column_type!(LogicalType::Double; f64); +#[cfg(feature = "decimal")] +impl_model_column_type!(LogicalType::Decimal(None, None); Decimal); +impl_model_column_type!(LogicalType::Varchar(None, CharLengthUnits::Characters); String, Arc); impl StringType for String {} impl StringType for Arc {} +#[cfg(feature = "decimal")] impl DecimalType for Decimal {} +#[cfg(feature = "time")] +mod chrono_orm { + use super::{FromDataValue, ModelColumnType, ToDataValue}; + use crate::types::value::DataValue; + use crate::types::LogicalType; + use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; + + impl_from_data_value_by_method!(NaiveDate, date); + impl_from_data_value_by_method!(NaiveDateTime, datetime); + impl_from_data_value_by_method!(NaiveTime, time); + + impl_model_column_type!(LogicalType::Date; NaiveDate); + impl_model_column_type!(LogicalType::DateTime; NaiveDateTime); + impl_model_column_type!(LogicalType::Time(Some(0)); NaiveTime); + + impl ToDataValue for NaiveDate { + fn to_data_value(&self) -> DataValue { + DataValue::from(self) + } + } + + impl ToDataValue for NaiveDateTime { + fn to_data_value(&self) -> DataValue { + DataValue::from(self) + } + } + + impl ToDataValue for NaiveTime { + fn to_data_value(&self) -> DataValue { + DataValue::from(self) + } + } +} + impl FromDataValue for String { fn logical_type() -> Option { LogicalType::type_trans::() @@ -5356,24 +2869,6 @@ impl ToDataValue for &str { } } -impl ToDataValue for NaiveDate { - fn to_data_value(&self) -> DataValue { - DataValue::from(self) - } -} - -impl ToDataValue for NaiveDateTime { - fn to_data_value(&self) -> DataValue { - DataValue::from(self) - } -} - -impl ToDataValue for NaiveTime { - fn to_data_value(&self) -> DataValue { - DataValue::from(self) - } -} - impl FromDataValue for Option { fn logical_type() -> Option { T::logical_type() @@ -5398,8 +2893,8 @@ impl ToDataValue for Option { } impl ModelColumnType for Option { - fn ddl_type() -> String { - T::ddl_type() + fn logical_type() -> LogicalType { + T::logical_type() } fn nullable() -> bool { @@ -5453,96 +2948,38 @@ impl_from_query_tuple!( (A, B, C, D, E, F, G, H), ); -fn normalize_sql_fragment(value: &str) -> String { - value - .split_whitespace() - .collect::>() - .join(" ") - .to_ascii_lowercase() -} - -fn canonicalize_model_type(value: &str) -> String { - let normalized = normalize_sql_fragment(value); - - match normalized.as_str() { - "boolean" => "boolean".to_string(), - "tinyint" => "tinyint".to_string(), - "smallint" => "smallint".to_string(), - "int" | "integer" => "integer".to_string(), - "bigint" => "bigint".to_string(), - "utinyint" => "utinyint".to_string(), - "usmallint" => "usmallint".to_string(), - "unsigned integer" | "uinteger" => "uinteger".to_string(), - "ubigint" => "ubigint".to_string(), - "float" => "float".to_string(), - "double" => "double".to_string(), - "date" => "date".to_string(), - "datetime" => "datetime".to_string(), - "time" => "time(some(0))".to_string(), - "varchar" => "varchar(none, characters)".to_string(), - "decimal" => "decimal(none, none)".to_string(), - _ => { - if let Some(inner) = normalized - .strip_prefix("varchar(") - .and_then(|value| value.strip_suffix(')')) - { - return ::std::format!("varchar(some({}), characters)", inner.trim()); - } - if let Some(inner) = normalized - .strip_prefix("char(") - .and_then(|value| value.strip_suffix(')')) - { - return ::std::format!("char({}, characters)", inner.trim()); - } - if let Some(inner) = normalized - .strip_prefix("decimal(") - .and_then(|value| value.strip_suffix(')')) - { - let parts = inner.split(',').map(str::trim).collect::>(); - return match parts.as_slice() { - [precision] => ::std::format!("decimal(some({precision}), none)"), - [precision, scale] => { - ::std::format!("decimal(some({precision}), some({scale}))") - } - _ => normalized, - }; - } - normalized - } - } -} - -fn model_column_default(model: &OrmColumn) -> Option { - model.default_expr.map(normalize_sql_fragment) +fn model_column_default(model: &ColumnCatalog) -> Result, DatabaseError> { + model.default_value() } -fn catalog_column_default(column: &ColumnRef) -> Option { - column - .desc() - .default - .as_ref() - .map(|expr| normalize_sql_fragment(&expr.to_string())) +fn catalog_column_default(column: &ColumnCatalog) -> Result, DatabaseError> { + column.default_value() } -fn model_column_type_matches_catalog(model: &OrmColumn, column: &ColumnRef) -> bool { - canonicalize_model_type(&model.ddl_type) - == normalize_sql_fragment(&column.datatype().to_string()) +fn model_column_type_matches_catalog(model: &ColumnCatalog, column: &ColumnCatalog) -> bool { + model.datatype() == column.datatype() } -fn model_column_matches_catalog(model: &OrmColumn, column: &ColumnRef) -> bool { - model.primary_key == column.desc().is_primary() - && model.unique == column.desc().is_unique() - && model.nullable == column.nullable() +fn model_column_matches_catalog( + model: &ColumnCatalog, + column: &ColumnCatalog, +) -> Result { + Ok(model.desc().is_primary() == column.desc().is_primary() + && model.desc().is_unique() == column.desc().is_unique() + && model.nullable() == column.nullable() && model_column_type_matches_catalog(model, column) - && model_column_default(model) == catalog_column_default(column) + && model_column_default(model)? == catalog_column_default(column)?) } -fn model_column_rename_compatible(model: &OrmColumn, column: &ColumnRef) -> bool { - model.primary_key == column.desc().is_primary() - && model.unique == column.desc().is_unique() - && model.nullable == column.nullable() +fn model_column_rename_compatible( + model: &ColumnCatalog, + column: &ColumnCatalog, +) -> Result { + Ok(model.desc().is_primary() == column.desc().is_primary() + && model.desc().is_unique() == column.desc().is_unique() + && model.nullable() == column.nullable() && model_column_type_matches_catalog(model, column) - && model_column_default(model) == catalog_column_default(column) + && model_column_default(model)? == catalog_column_default(column)?) } fn extract_optional_model(iter: I) -> Result, DatabaseError> @@ -5556,12 +2993,13 @@ where fn extract_optional_row(mut iter: I) -> Result, DatabaseError> where I: ResultIter, - T: for<'a> From<(&'a SchemaRef, Tuple)>, + T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, { - let schema = iter.schema().clone(); - Ok(match iter.next() { - Some(tuple) => Some(T::from((&schema, tuple?))), + Some(tuple) => { + let tuple = tuple?; + Some(iter.schema(|schema| T::from((schema, tuple)))) + } None => None, }) } @@ -5608,62 +3046,48 @@ fn extract_projected_tuple(tuple: Tuple) -> Result(mut iter: I) -> Result, DatabaseError> -where - I: ResultIter, - T: FromDataValue, -{ - Ok(match iter.next() { - Some(tuple) => Some(extract_value_from_tuple(tuple?)?), - None => None, - }) -} - -fn extract_optional_tuple(mut iter: I) -> Result, DatabaseError> -where - I: ResultIter, - T: FromQueryTuple, -{ - Ok(match iter.next() { - Some(tuple) => Some(extract_projected_tuple(tuple?)?), - None => None, - }) -} - -fn orm_analyze(executor: E) -> Result<(), DatabaseError> { +fn orm_analyze(executor: E) -> Result<(), DatabaseError> { executor - .execute_statement(M::analyze_statement(), &[])? + .execute(&[], |binder, arena| { + binder.bind_analyze(M::table_name().into(), arena) + })? .done() } -fn orm_insert(executor: E, model: &M) -> Result<(), DatabaseError> { - orm_insert_model(executor, M::insert_statement(), model.params()) -} - -fn orm_insert_model( - executor: E, - statement: &Statement, - params: A, -) -> Result<(), DatabaseError> -where - E: StatementSource, - A: AsRef<[(&'static str, DataValue)]>, -{ - executor.execute_statement(statement, params)?.done() +fn orm_insert(executor: E, model: &M) -> Result<(), DatabaseError> { + let params = model.params(); + executor + .execute(&[], |binder, arena| { + bind_orm_insert_model::<_, _, M>(binder, params, arena) + })? + .done() } -fn orm_get( +fn orm_get( executor: E, key: &M::PrimaryKey, ) -> Result, DatabaseError> { - let params = [(M::primary_key_field().placeholder, key.to_data_value())]; - extract_optional_model(executor.execute_statement(M::find_statement(), params)?) -} - -fn orm_list( - executor: E, -) -> Result, DatabaseError> { - Ok(executor - .execute_statement(M::select_statement(), &[])? - .orm::()) + let primary_key = M::primary_key_field(); + let key = key.to_data_value(); + extract_optional_model(bind_orm_context(executor, |ctx| { + let plan: LogicalPlan = ctx + .from::()? + .filter(|expr| { + let column = expr.qualified_column( + M::table_name(), + Field::::new(M::table_name(), primary_key.column), + )?; + column.eq(expr.data_value(key)) + })? + .finish()?; + Ok(plan) + })?) +} + +fn orm_list(executor: E) -> Result, DatabaseError> { + Ok(bind_orm_context(executor, |ctx| { + let plan: LogicalPlan = ctx.from::()?.finish()?; + Ok(plan) + })? + .orm::()) } diff --git a/src/planner/arena.rs b/src/planner/arena.rs new file mode 100644 index 00000000..dfdf4d98 --- /dev/null +++ b/src/planner/arena.rs @@ -0,0 +1,609 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::catalog::{ColumnCatalog, ColumnRef, TableName}; +use crate::types::index::{IndexMeta, IndexMetaRef}; +use crate::types::tuple::Schema; +use std::cell::UnsafeCell; +use std::collections::HashSet; +use std::fmt; + +pub struct TableArena { + dummy_columns: [ColumnCatalog; DUMMY_COLUMN_COUNT], + columns: Vec, + indexes: Vec, + version: usize, +} + +struct TableArenaColumn { + catalog: ColumnCatalog, + live: bool, +} + +struct TableArenaIndex { + meta: IndexMeta, + live: bool, +} + +pub struct TableArenaCell { + value: UnsafeCell, +} + +// SAFETY: table arena mutation is only exposed through database APIs that require +// `&mut Database`; read execution only borrows already-loaded metadata. +unsafe impl Send for TableArenaCell {} +unsafe impl Sync for TableArenaCell {} + +#[derive(Debug)] +pub struct PlanArena<'a> { + table_arena: &'a TableArenaCell, + #[cfg(debug_assertions)] + table_arena_version: usize, + allocated_columns_len: usize, + temp_table_id: usize, + columns: Vec, + indexes: Vec, +} + +pub trait MetaArena { + fn alloc_column(&mut self, column: ColumnCatalog) -> ColumnRef; + + fn alloc_index(&mut self, index: IndexMeta) -> IndexMetaRef; + + fn alloc_columns(&mut self, columns: I) -> Schema + where + Self: Sized, + I: IntoIterator, + { + columns + .into_iter() + .map(|column| self.alloc_column(column)) + .collect() + } + + fn column(&self, column: ColumnRef) -> &ColumnCatalog; + + fn index(&self, index: IndexMetaRef) -> &IndexMeta; + + fn find_column(&self, column: &ColumnCatalog) -> Option; + + fn find_index(&self, index: &IndexMeta) -> Option; +} + +const DUMMY_COLUMN_NAMES: [&str; DUMMY_COLUMN_COUNT] = [ + "TABLE", + "VIEW", + "PLAN", + "FIELD", + "TYPE", + "LEN", + "NULL", + "Key", + "DEFAULT", + "COLUMN_REF", + "INSERTED", + "UPDATED", + "DELETED", + "STATISTICS_META_PATH", + "ADD COLUMN SUCCESS", + "CHANGE COLUMN SUCCESS", + "DROP COLUMN SUCCESS", + "CREATE TABLE SUCCESS", + "CREATE INDEX SUCCESS", + "CREATE VIEW SUCCESS", + "DROP TABLE SUCCESS", + "DROP VIEW SUCCESS", + "DROP INDEX SUCCESS", + "TRUNCATE TABLE SUCCESS", + "COPY FROM SOURCE", + "COPY TO TARGET", +]; +const DUMMY_COLUMN_COUNT: usize = 26; +const DUMMY_COLUMN_BASE: usize = usize::MAX - DUMMY_COLUMN_COUNT + 1; + +impl TableArenaCell { + pub(crate) fn new(value: TableArena) -> Self { + Self { + value: UnsafeCell::new(value), + } + } + + pub(crate) fn borrow(&self) -> &TableArena { + unsafe { &*self.value.get() } + } + + #[allow(clippy::mut_from_ref)] + pub(crate) fn borrow_mut(&self) -> &mut TableArena { + unsafe { &mut *self.value.get() } + } +} + +impl Default for TableArenaCell { + fn default() -> Self { + Self::new(TableArena::default()) + } +} + +impl Default for TableArena { + fn default() -> Self { + Self { + dummy_columns: std::array::from_fn(|i| { + ColumnCatalog::new_dummy(DUMMY_COLUMN_NAMES[i].to_string()) + }), + columns: Vec::new(), + indexes: Vec::new(), + version: 0, + } + } +} + +impl fmt::Debug for TableArenaCell { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.borrow().fmt(f) + } +} + +impl TableArena { + pub(crate) fn alloc_dummy(&self, name: &str) -> ColumnRef { + DUMMY_COLUMN_NAMES + .iter() + .position(|dummy_name| *dummy_name == name) + .map(|index| ColumnRef::new(DUMMY_COLUMN_BASE + index)) + .unwrap_or_else(|| panic!("unknown dummy column: {name}")) + } + + pub fn alloc_table_column( + &mut self, + table_name: TableName, + mut column: ColumnCatalog, + ) -> ColumnRef { + column.set_ref_table(table_name, ulid::Ulid::new(), false); + self.alloc_column(column) + } + + pub(crate) fn alloc_column(&mut self, column: ColumnCatalog) -> ColumnRef { + ::alloc_column(self, column) + } + + pub fn alloc_index(&mut self, index: IndexMeta) -> IndexMetaRef { + ::alloc_index(self, index) + } + + pub(crate) fn column(&self, column: ColumnRef) -> &ColumnCatalog { + ::column(self, column) + } + + pub(crate) fn index(&self, index: IndexMetaRef) -> &IndexMeta { + ::index(self, index) + } + + fn dummy_column(&self, column: ColumnRef) -> Option<&ColumnCatalog> { + column + .pos() + .checked_sub(DUMMY_COLUMN_BASE) + .and_then(|index| self.dummy_columns.get(index)) + } + + pub(crate) fn columns_len(&self) -> usize { + self.columns.len() + } + + pub(crate) fn indexes_len(&self) -> usize { + self.indexes.len() + } + + pub(crate) fn live_columns_len(&self) -> usize { + self.columns.iter().filter(|column| column.live).count() + } + + pub(crate) fn version(&self) -> usize { + self.version + } + + pub(crate) fn recycle_unreferenced_positions(&mut self, live_columns: HashSet) { + let mut changed = false; + + for (pos, column) in self.columns.iter_mut().enumerate() { + let live = live_columns.contains(&pos); + if column.live != live { + column.live = live; + changed = true; + } + } + + if changed { + self.increment_version(); + } + } + + fn increment_version(&mut self) { + self.version = self + .version + .checked_add(1) + .expect("TableArena version overflow"); + } +} + +impl fmt::Debug for TableArena { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TableArena") + .field("columns_len", &self.live_columns_len()) + .field("slots_len", &self.columns_len()) + .field("indexes_len", &self.indexes_len()) + .field("version", &self.version()) + .finish() + } +} + +impl MetaArena for TableArena { + fn alloc_column(&mut self, column: ColumnCatalog) -> ColumnRef { + if let Some(column_ref) = self.find_column(&column) { + return column_ref; + } + + if let Some((pos, slot)) = self + .columns + .iter_mut() + .enumerate() + .find(|(_, column)| !column.live) + { + *slot = TableArenaColumn { + catalog: column, + live: true, + }; + self.increment_version(); + return ColumnRef::new(pos); + } + + let pos = self.columns.len(); + self.columns.push(TableArenaColumn { + catalog: column, + live: true, + }); + self.increment_version(); + ColumnRef::new(pos) + } + + fn alloc_index(&mut self, index: IndexMeta) -> IndexMetaRef { + if let Some(index_ref) = self.find_index(&index) { + return index_ref; + } + + if let Some((pos, slot)) = self + .indexes + .iter_mut() + .enumerate() + .find(|(_, index)| !index.live) + { + *slot = TableArenaIndex { + meta: index, + live: true, + }; + self.increment_version(); + return IndexMetaRef::new(pos); + } + + let pos = self.indexes.len(); + self.indexes.push(TableArenaIndex { + meta: index, + live: true, + }); + self.increment_version(); + IndexMetaRef::new(pos) + } + + fn column(&self, column: ColumnRef) -> &ColumnCatalog { + if let Some(column) = self.dummy_column(column) { + return column; + } + let column = &self.columns[column.pos()]; + if !column.live { + panic!("accessing recycled TableArena column"); + } + &column.catalog + } + + fn index(&self, index: IndexMetaRef) -> &IndexMeta { + let index = &self.indexes[index.pos()]; + if !index.live { + panic!("accessing recycled TableArena index"); + } + &index.meta + } + + fn find_column(&self, column: &ColumnCatalog) -> Option { + self.columns + .iter() + .position(|candidate| candidate.live && candidate.catalog == *column) + .map(ColumnRef::new) + } + + fn find_index(&self, index: &IndexMeta) -> Option { + self.indexes + .iter() + .position(|candidate| candidate.live && candidate.meta == *index) + .map(IndexMetaRef::new) + } +} + +impl<'a> PlanArena<'a> { + pub fn new(table_arena: &'a TableArenaCell) -> Self { + #[cfg(debug_assertions)] + let table_arena_version = table_arena.borrow().version(); + Self { + table_arena, + #[cfg(debug_assertions)] + table_arena_version, + allocated_columns_len: 0, + temp_table_id: 0, + columns: Vec::new(), + indexes: Vec::new(), + } + } + + pub(crate) fn table_arena_cell(&self) -> &'a TableArenaCell { + self.table_arena + } + + pub(crate) fn materialize_into_table_arena(&self) { + self.assert_table_arena_unchanged(); + + let table_arena = self.table_arena.borrow_mut(); + // PlanArena column refs are encoded as a contiguous suffix after the + // TableArena slots that existed when the plan was built. For cached + // view plans we preserve those refs verbatim by appending this suffix + // into TableArena at the same positions. This deliberately skips the + // dead-slot reuse path used by ordinary table metadata allocation: + // a little short-term slack is cheaper than remapping every ColumnRef + // embedded in a view plan, and recycle_unreferenced can reclaim the + // older dead slots for future non-view allocations. + for column in &self.columns { + table_arena.columns.push(TableArenaColumn { + catalog: column.clone(), + live: true, + }); + } + for index in &self.indexes { + table_arena.indexes.push(TableArenaIndex { + meta: index.clone(), + live: true, + }); + } + if !self.columns.is_empty() || !self.indexes.is_empty() { + table_arena.increment_version(); + } + } + + #[cfg(debug_assertions)] + fn assert_table_arena_unchanged(&self) { + let current_version = self.table_arena.borrow().version(); + if current_version != self.table_arena_version { + panic!("TableArena was modified while PlanArena is still active"); + } + } + + #[cfg(not(debug_assertions))] + fn assert_table_arena_unchanged(&self) {} + + pub(crate) fn clone_column(&self, column: ColumnRef) -> ColumnCatalog { + self.column(column).clone() + } + + pub(crate) fn same_column(&self, left: ColumnRef, right: ColumnRef) -> bool { + self.column(left).summary() == self.column(right).summary() + } + + pub(crate) fn nullable_for_join( + &mut self, + column: ColumnRef, + nullable: bool, + ) -> Option { + let source = self.column(column); + if source.nullable() == nullable { + return None; + } + // FIXME + let mut joined = source.clone(); + joined.set_nullable(nullable); + joined.set_in_join(true); + Some(self.alloc_column(joined)) + } + + pub(crate) fn alloc_column(&mut self, column: ColumnCatalog) -> ColumnRef { + ::alloc_column(self, column) + } + + pub(crate) fn allocated_columns_len(&self) -> usize { + self.allocated_columns_len + } + + pub fn alloc_index(&mut self, index: IndexMeta) -> IndexMetaRef { + ::alloc_index(self, index) + } + + pub(crate) fn alloc_dummy(&mut self, name: &str) -> ColumnRef { + self.assert_table_arena_unchanged(); + self.table_arena.borrow().alloc_dummy(name) + } + + pub(crate) fn temp_table(&mut self) -> TableName { + let table_name = format!("_temp_table_{}_", self.temp_table_id); + self.temp_table_id += 1; + table_name.into() + } + + pub fn column(&self, column: ColumnRef) -> &ColumnCatalog { + ::column(self, column) + } + + #[allow(clippy::should_implement_trait)] + pub fn index(&self, index: IndexMetaRef) -> &IndexMeta { + ::index(self, index) + } +} + +impl MetaArena for PlanArena<'_> { + fn alloc_column(&mut self, column: ColumnCatalog) -> ColumnRef { + self.assert_table_arena_unchanged(); + self.allocated_columns_len += 1; + + if let Some(column_ref) = self.find_column(&column) { + return column_ref; + } + + let pos = self.table_arena.borrow().columns_len() + self.columns.len(); + self.columns.push(column); + ColumnRef::new(pos) + } + + fn alloc_index(&mut self, index: IndexMeta) -> IndexMetaRef { + self.assert_table_arena_unchanged(); + + if let Some(index_ref) = self.find_index(&index) { + return index_ref; + } + + let pos = self.table_arena.borrow().indexes_len() + self.indexes.len(); + self.indexes.push(index); + IndexMetaRef::new(pos) + } + + fn column(&self, column: ColumnRef) -> &ColumnCatalog { + self.assert_table_arena_unchanged(); + let table_arena = self.table_arena.borrow(); + if let Some(column) = table_arena.dummy_column(column) { + return column; + } + let table_columns_len = table_arena.columns_len(); + if column.pos() < table_columns_len { + table_arena.column(column) + } else { + &self.columns[column.pos() - table_columns_len] + } + } + + fn index(&self, index: IndexMetaRef) -> &IndexMeta { + self.assert_table_arena_unchanged(); + let table_arena = self.table_arena.borrow(); + let table_indexes_len = table_arena.indexes_len(); + if index.pos() < table_indexes_len { + table_arena.index(index) + } else { + &self.indexes[index.pos() - table_indexes_len] + } + } + + fn find_column(&self, column: &ColumnCatalog) -> Option { + self.assert_table_arena_unchanged(); + let table_arena = self.table_arena.borrow(); + if column.is_persistent_table_column() { + if let Some(column_ref) = table_arena.find_column(column) { + return Some(column_ref); + } + } + let table_columns_len = table_arena.columns_len(); + self.columns + .iter() + .position(|candidate| candidate == column) + .map(|offset| ColumnRef::new(table_columns_len + offset)) + } + + fn find_index(&self, index: &IndexMeta) -> Option { + self.assert_table_arena_unchanged(); + let table_arena = self.table_arena.borrow(); + if let Some(index_ref) = table_arena.find_index(index) { + return Some(index_ref); + } + let table_indexes_len = table_arena.indexes_len(); + self.indexes + .iter() + .position(|candidate| candidate == index) + .map(|offset| IndexMetaRef::new(table_indexes_len + offset)) + } +} + +#[cfg(test)] +mod tests { + use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::types::LogicalType; + + fn column(name: &str) -> ColumnCatalog { + ColumnCatalog::new( + name.to_string(), + true, + ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), + ) + } + + fn table_column(name: &str, is_temp: bool) -> ColumnCatalog { + let mut column = column(name); + column.set_ref_table("t".to_string().into(), ulid::Ulid::new(), is_temp); + column + } + + #[test] + fn table_arena_reuses_recycled_slot() { + let arena = crate::planner::TableArenaCell::default(); + let first = arena.borrow_mut().alloc_column(column("a")); + let second = arena.borrow_mut().alloc_column(column("b")); + + arena + .borrow_mut() + .recycle_unreferenced_positions([first.pos()].into_iter().collect()); + let reused = arena.borrow_mut().alloc_column(column("c")); + + assert_eq!(reused, second); + assert_eq!(arena.borrow().column(reused).name(), "c"); + assert_eq!(arena.borrow().columns_len(), 2); + assert_eq!(arena.borrow().live_columns_len(), 2); + } + + #[test] + fn plan_arena_reuses_only_persistent_table_columns_from_table_arena() { + let table_arena = crate::planner::TableArenaCell::default(); + let persistent = table_column("a", false); + let temp = table_column("b", true); + + let persistent_ref = table_arena.borrow_mut().alloc_column(persistent.clone()); + let temp_table_ref = table_arena.borrow_mut().alloc_column(temp.clone()); + + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + assert_eq!(plan_arena.alloc_column(persistent), persistent_ref); + + let temp_plan_ref = plan_arena.alloc_column(temp); + assert_ne!(temp_plan_ref, temp_table_ref); + assert!(temp_plan_ref.pos() >= table_arena.borrow().columns_len()); + } + + #[test] + fn materializing_plan_arena_preserves_local_column_positions() { + let table_arena = crate::planner::TableArenaCell::default(); + let first = table_arena.borrow_mut().alloc_column(column("a")); + let second = table_arena.borrow_mut().alloc_column(column("b")); + table_arena + .borrow_mut() + .recycle_unreferenced_positions([first.pos()].into_iter().collect()); + + let mut plan_arena = crate::planner::PlanArena::new(&table_arena); + let local = plan_arena.alloc_column(column("c")); + assert_eq!(local.pos(), 2); + + plan_arena.materialize_into_table_arena(); + + assert_eq!(table_arena.borrow().columns_len(), 3); + assert_eq!(table_arena.borrow().column(local).name(), "c"); + assert_eq!(table_arena.borrow().column(first).name(), "a"); + assert_eq!(second.pos(), 1); + } +} diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 1d3660dd..05122885 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -12,23 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod arena; pub mod operator; -use crate::catalog::{ColumnCatalog, ColumnRef, TableName}; +use crate::catalog::TableName; use crate::planner::operator::set_membership::SetMembershipOperator; use crate::planner::operator::union::UnionOperator; use crate::planner::operator::values::ValuesOperator; use crate::planner::operator::{Operator, PhysicalOption}; -use crate::types::tuple::{Schema, SchemaRef}; -use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; -use std::sync::Arc; +use std::hash::{Hash, Hasher}; -#[derive(Debug, Clone)] -pub(crate) enum SchemaOutput { - Schema(Schema), - SchemaRef(SchemaRef), -} +pub use arena::{MetaArena, PlanArena, TableArena, TableArenaCell}; #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub enum Childrens { @@ -96,22 +91,12 @@ impl<'a> Iterator for ChildrensIter<'a> { } } -#[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] +#[derive(Debug)] pub struct LogicalPlan { pub(crate) operator: Operator, pub(crate) childrens: Box, pub(crate) physical_option: Option, - - pub(crate) _output_schema_ref: Option, -} - -impl SchemaOutput { - pub(crate) fn columns(&self) -> impl Iterator { - match self { - SchemaOutput::Schema(schema) => schema.iter(), - SchemaOutput::SchemaRef(schema_ref) => schema_ref.iter(), - } - } + output_schema: Option, } impl LogicalPlan { @@ -120,7 +105,7 @@ impl LogicalPlan { operator, childrens: Box::new(childrens), physical_option: None, - _output_schema_ref: None, + output_schema: None, } } @@ -128,22 +113,6 @@ impl LogicalPlan { std::mem::replace(self, Self::new(Operator::Dummy, Childrens::None)) } - pub(crate) fn reset_output_schema_cache(&mut self) { - self._output_schema_ref = None; - } - - pub(crate) fn reset_output_schema_cache_recursive(&mut self) { - self.reset_output_schema_cache(); - match self.childrens.as_mut() { - Childrens::Only(child) => child.reset_output_schema_cache_recursive(), - Childrens::Twins { left, right } => { - left.reset_output_schema_cache_recursive(); - right.reset_output_schema_cache_recursive(); - } - Childrens::None => (), - } - } - pub fn referenced_table(&self) -> Vec { fn collect_table(plan: &LogicalPlan, results: &mut Vec) { if let Operator::TableScan(op) = &plan.operator { @@ -159,56 +128,93 @@ impl LogicalPlan { tables } - pub(crate) fn _output_schema_direct( - operator: &Operator, - mut childrens_iter: ChildrensIter, - ) -> SchemaOutput { + pub(crate) fn visit_column_refs(&self, arena: &mut A, f: &mut F) + where + A: MetaArena, + F: FnMut(&crate::catalog::ColumnRef) + ?Sized, + { + self.operator + .visit_referenced_columns(arena, &mut |_, column| { + f(column); + true + }); + for child in self.childrens.iter() { + child.visit_column_refs(arena, f); + } + } + + pub fn output_schema<'plan>( + &'plan mut self, + arena: &mut PlanArena, + ) -> &'plan crate::types::tuple::Schema { + let LogicalPlan { + operator, + childrens, + output_schema, + .. + } = self; + output_schema.get_or_insert_with(|| Self::compute_output_schema(operator, childrens, arena)) + } + + pub fn take_schema(&mut self, arena: &mut PlanArena) -> crate::types::tuple::Schema { + let LogicalPlan { + operator, + childrens, + output_schema, + .. + } = self; + output_schema + .take() + .unwrap_or_else(|| Self::compute_output_schema(operator, childrens, arena)) + } + + fn compute_output_schema( + operator: &mut Operator, + childrens: &mut Childrens, + arena: &mut PlanArena, + ) -> crate::types::tuple::Schema { match operator { Operator::Filter(_) | Operator::Sort(_) | Operator::Limit(_) | Operator::TopK(_) - | Operator::ScalarSubquery(_) => childrens_iter.next().unwrap().output_schema_direct(), - Operator::ScalarApply(_) | Operator::Join(_) => { - let mut columns = Vec::new(); - - for plan in childrens_iter { - for column in plan.output_schema_direct().columns() { - columns.push(column.clone()); - } + | Operator::ScalarSubquery(_) => match childrens { + Childrens::Only(child) => child.output_schema(arena).clone(), + _ => unreachable!(), + }, + Operator::ScalarApply(_) | Operator::Join(_) => match childrens { + Childrens::Twins { left, right } => { + let mut schema = left.output_schema(arena).clone(); + schema.extend_from_slice(right.output_schema(arena)); + schema } - SchemaOutput::Schema(columns) - } + _ => unreachable!(), + }, Operator::MarkApply(op) => { - let mut columns = Vec::new(); - - if let Some(left) = childrens_iter.next() { - for column in left.output_schema_direct().columns() { - columns.push(column.clone()); - } - } - columns.push(op.output_column().clone()); - - SchemaOutput::Schema(columns) - } - Operator::Aggregate(op) => SchemaOutput::Schema( - op.agg_calls - .iter() - .chain(op.groupby_exprs.iter()) - .map(|expr| expr.output_column()) - .collect_vec(), - ), - Operator::Project(op) => SchemaOutput::Schema( - op.exprs - .iter() - .map(|expr| expr.output_column()) - .collect_vec(), - ), - Operator::TableScan(op) => { - SchemaOutput::Schema(op.columns.iter().cloned().collect_vec()) + let mut schema = match childrens { + Childrens::Only(left) => left.output_schema(arena).clone(), + Childrens::Twins { left, .. } => left.output_schema(arena).clone(), + Childrens::None => Vec::new(), + }; + schema.push(*op.output_column()); + schema } + Operator::Aggregate(op) => op + .agg_calls + .iter() + .chain(op.groupby_exprs.iter()) + .map(|expr| expr.output_column_ref(arena)) + .collect(), + Operator::Project(op) => op + .exprs + .iter() + .map(|expr| expr.output_column_ref(arena)) + .collect(), + Operator::TableScan(op) => op.columns.clone(), Operator::FunctionScan(op) => { - SchemaOutput::SchemaRef(op.table_function.output_schema().clone()) + let mut schema = Vec::new(); + op.table_function.output_schema_into(&mut schema); + schema } Operator::Values(ValuesOperator { schema_ref, .. }) | Operator::Union(UnionOperator { @@ -218,90 +224,72 @@ impl LogicalPlan { | Operator::SetMembership(SetMembershipOperator { left_schema_ref: schema_ref, .. - }) => SchemaOutput::SchemaRef(schema_ref.clone()), - Operator::Dummy => SchemaOutput::Schema(vec![]), - Operator::ShowTable => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("TABLE".to_string()), - )]), - Operator::ShowView => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("VIEW".to_string()), - )]), - Operator::Explain => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("PLAN".to_string()), - )]), - Operator::Describe(_) => SchemaOutput::Schema(vec![ - ColumnRef::from(ColumnCatalog::new_dummy("FIELD".to_string())), - ColumnRef::from(ColumnCatalog::new_dummy("TYPE".to_string())), - ColumnRef::from(ColumnCatalog::new_dummy("LEN".to_string())), - ColumnRef::from(ColumnCatalog::new_dummy("NULL".to_string())), - ColumnRef::from(ColumnCatalog::new_dummy("Key".to_string())), - ColumnRef::from(ColumnCatalog::new_dummy("DEFAULT".to_string())), - ]), - Operator::Insert(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("INSERTED".to_string()), - )]), - Operator::Update(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("UPDATED".to_string()), - )]), - Operator::Delete(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("DELETED".to_string()), - )]), - Operator::Analyze(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("STATISTICS_META_PATH".to_string()), - )]), - Operator::AddColumn(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("ADD COLUMN SUCCESS".to_string()), - )]), - Operator::ChangeColumn(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("CHANGE COLUMN SUCCESS".to_string()), - )]), - Operator::DropColumn(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("DROP COLUMN SUCCESS".to_string()), - )]), - Operator::CreateTable(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("CREATE TABLE SUCCESS".to_string()), - )]), - Operator::CreateIndex(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("CREATE INDEX SUCCESS".to_string()), - )]), - Operator::CreateView(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("CREATE VIEW SUCCESS".to_string()), - )]), - Operator::DropTable(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("DROP TABLE SUCCESS".to_string()), - )]), - Operator::DropView(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("DROP VIEW SUCCESS".to_string()), - )]), - Operator::DropIndex(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("DROP INDEX SUCCESS".to_string()), - )]), - Operator::Truncate(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("TRUNCATE TABLE SUCCESS".to_string()), - )]), - Operator::CopyFromFile(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("COPY FROM SOURCE".to_string()), - )]), - Operator::CopyToFile(_) => SchemaOutput::Schema(vec![ColumnRef::from( - ColumnCatalog::new_dummy("COPY TO TARGET".to_string()), - )]), + }) => schema_ref.clone(), + Operator::Dummy => Vec::new(), + Operator::ShowTable => Self::dummy_schema(arena, ["TABLE"]), + Operator::ShowView => Self::dummy_schema(arena, ["VIEW"]), + Operator::Explain => Self::dummy_schema(arena, ["PLAN"]), + Operator::Describe(_) => Self::dummy_schema( + arena, + [ + "FIELD", + "TYPE", + "LEN", + "NULL", + "Key", + "DEFAULT", + "COLUMN_REF", + ], + ), + Operator::Insert(_) => Self::dummy_schema(arena, ["INSERTED"]), + Operator::Update(_) => Self::dummy_schema(arena, ["UPDATED"]), + Operator::Delete(_) => Self::dummy_schema(arena, ["DELETED"]), + Operator::Analyze(_) => Self::dummy_schema(arena, ["STATISTICS_META_PATH"]), + Operator::AddColumn(_) => Self::dummy_schema(arena, ["ADD COLUMN SUCCESS"]), + Operator::ChangeColumn(_) => Self::dummy_schema(arena, ["CHANGE COLUMN SUCCESS"]), + Operator::DropColumn(_) => Self::dummy_schema(arena, ["DROP COLUMN SUCCESS"]), + Operator::CreateTable(_) => Self::dummy_schema(arena, ["CREATE TABLE SUCCESS"]), + Operator::CreateIndex(_) => Self::dummy_schema(arena, ["CREATE INDEX SUCCESS"]), + Operator::CreateView(_) => Self::dummy_schema(arena, ["CREATE VIEW SUCCESS"]), + Operator::DropTable(_) => Self::dummy_schema(arena, ["DROP TABLE SUCCESS"]), + Operator::DropView(_) => Self::dummy_schema(arena, ["DROP VIEW SUCCESS"]), + Operator::DropIndex(_) => Self::dummy_schema(arena, ["DROP INDEX SUCCESS"]), + Operator::Truncate(_) => Self::dummy_schema(arena, ["TRUNCATE TABLE SUCCESS"]), + #[cfg(feature = "copy")] + Operator::CopyFromFile(_) => Self::dummy_schema(arena, ["COPY FROM SOURCE"]), + #[cfg(feature = "copy")] + Operator::CopyToFile(_) => Self::dummy_schema(arena, ["COPY TO TARGET"]), } } - pub(crate) fn output_schema_direct(&self) -> SchemaOutput { - Self::_output_schema_direct(&self.operator, self.childrens.iter()) + fn dummy_schema( + arena: &mut PlanArena, + names: [&str; N], + ) -> crate::types::tuple::Schema { + names + .into_iter() + .map(|name| arena.alloc_dummy(name)) + .collect() } - pub fn output_schema(&mut self) -> &SchemaRef { - self._output_schema_ref.get_or_insert_with(|| { - match Self::_output_schema_direct(&self.operator, self.childrens.iter()) { - SchemaOutput::Schema(schema) => Arc::new(schema), - SchemaOutput::SchemaRef(schema_ref) => schema_ref.clone(), + pub fn reset_output_schema_cache(&mut self) { + self.output_schema = None; + } + + pub fn reset_output_schema_cache_recursive(&mut self) { + self.reset_output_schema_cache(); + match self.childrens.as_mut() { + Childrens::Only(child) => child.reset_output_schema_cache_recursive(), + Childrens::Twins { left, right } => { + left.reset_output_schema_cache_recursive(); + right.reset_output_schema_cache_recursive(); } - }) + Childrens::None => (), + } } - pub fn explain(&self, indentation: usize) -> String { + #[allow(clippy::only_used_in_recursion)] + pub fn explain(&self, arena: &mut PlanArena, indentation: usize) -> String { let mut result = format!("{:indent$}{}", "", self.operator, indent = indentation); if let Some(physical_option) = &self.physical_option { @@ -309,10 +297,106 @@ impl LogicalPlan { } for child in self.childrens.iter() { - result.push('\n'); - result.push_str(&child.explain(indentation + 2)); + let child = child.explain(arena, indentation + 2); + result.push(' '); + result.push_str(child.trim_start()); } result } } + +impl Clone for LogicalPlan { + fn clone(&self) -> Self { + Self { + operator: self.operator.clone(), + childrens: self.childrens.clone(), + physical_option: self.physical_option.clone(), + output_schema: None, + } + } +} + +impl PartialEq for LogicalPlan { + fn eq(&self, other: &Self) -> bool { + self.operator == other.operator + && self.childrens == other.childrens + && self.physical_option == other.physical_option + } +} + +impl Eq for LogicalPlan {} + +impl Hash for LogicalPlan { + fn hash(&self, state: &mut H) { + self.operator.hash(state); + self.childrens.hash(state); + self.physical_option.hash(state); + } +} + +impl crate::serdes::ReferenceSerialization for LogicalPlan { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut crate::serdes::ReferenceTables, + arena: &A, + ) -> Result<(), crate::errors::DatabaseError> { + crate::serdes::ReferenceSerialization::encode( + &self.operator, + writer, + is_direct, + reference_tables, + arena, + )?; + crate::serdes::ReferenceSerialization::encode( + &self.childrens, + writer, + is_direct, + reference_tables, + arena, + )?; + crate::serdes::ReferenceSerialization::encode( + &self.physical_option, + writer, + is_direct, + reference_tables, + arena, + ) + } + + fn decode( + reader: &mut R, + context: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + reference_tables: &crate::serdes::ReferenceTables, + arena: &mut A, + ) -> Result { + let operator = ::decode( + reader, + context, + reference_tables, + arena, + )?; + let childrens = as crate::serdes::ReferenceSerialization>::decode( + reader, + context, + reference_tables, + arena, + )?; + let physical_option = + as crate::serdes::ReferenceSerialization>::decode( + reader, + context, + reference_tables, + arena, + )?; + + Ok(Self { + operator, + childrens, + physical_option, + output_schema: None, + }) + } +} diff --git a/src/planner/operator/copy_from_file.rs b/src/planner/operator/copy_from_file.rs index b15df3c4..75439bc9 100644 --- a/src/planner/operator/copy_from_file.rs +++ b/src/planner/operator/copy_from_file.rs @@ -14,7 +14,7 @@ use crate::binder::copy::ExtSource; use crate::catalog::TableName; -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; use std::fmt; @@ -24,16 +24,12 @@ use std::fmt::Formatter; pub struct CopyFromFileOperator { pub table: TableName, pub source: ExtSource, - pub schema_ref: SchemaRef, + pub schema_ref: Schema, } impl fmt::Display for CopyFromFileOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let columns = self - .schema_ref - .iter() - .map(|column| column.name().to_string()) - .join(", "); + let columns = self.schema_ref.iter().join(", "); write!( f, "Copy {} -> {} [{}]", diff --git a/src/planner/operator/copy_to_file.rs b/src/planner/operator/copy_to_file.rs index 5e823e3e..218a88bb 100644 --- a/src/planner/operator/copy_to_file.rs +++ b/src/planner/operator/copy_to_file.rs @@ -13,8 +13,6 @@ // limitations under the License. use crate::binder::copy::ExtSource; -use crate::types::tuple::SchemaRef; -use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; use std::fmt; use std::fmt::Formatter; @@ -22,17 +20,11 @@ use std::fmt::Formatter; #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub struct CopyToFileOperator { pub target: ExtSource, - pub schema_ref: SchemaRef, } impl fmt::Display for CopyToFileOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let columns = self - .schema_ref - .iter() - .map(|column| column.name().to_string()) - .join(", "); - write!(f, "Copy To {} [{}]", self.target.path.display(), columns)?; + write!(f, "Copy To {}", self.target.path.display())?; Ok(()) } diff --git a/src/planner/operator/create_index.rs b/src/planner/operator/create_index.rs index 8df010a6..5ffa50da 100644 --- a/src/planner/operator/create_index.rs +++ b/src/planner/operator/create_index.rs @@ -31,11 +31,7 @@ pub struct CreateIndexOperator { impl fmt::Display for CreateIndexOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let columns = self - .columns - .iter() - .map(|column| column.name().to_string()) - .join(", "); + let columns = self.columns.iter().join(", "); write!( f, "Create Index On {} -> [{}], If Not Exists: {}", diff --git a/src/planner/operator/delete.rs b/src/planner/operator/delete.rs index 394c6206..cdce5b4f 100644 --- a/src/planner/operator/delete.rs +++ b/src/planner/operator/delete.rs @@ -20,6 +20,7 @@ use std::fmt::Formatter; #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub struct DeleteOperator { pub table_name: TableName, + // FIXME // for column pruning pub primary_keys: Vec, } diff --git a/src/planner/operator/mod.rs b/src/planner/operator/mod.rs index b16e0f24..bf7244d8 100644 --- a/src/planner/operator/mod.rs +++ b/src/planner/operator/mod.rs @@ -15,7 +15,9 @@ pub mod aggregate; pub mod alter_table; pub mod analyze; +#[cfg(feature = "copy")] pub mod copy_from_file; +#[cfg(feature = "copy")] pub mod copy_to_file; pub mod create_index; pub mod create_table; @@ -51,10 +53,11 @@ use self::{ table_scan::TableScanOperator, }; use crate::catalog::ColumnRef; -use crate::expression::ScalarExpression; use crate::planner::operator::alter_table::drop_column::DropColumnOperator; use crate::planner::operator::analyze::AnalyzeOperator; +#[cfg(feature = "copy")] use crate::planner::operator::copy_from_file::CopyFromFileOperator; +#[cfg(feature = "copy")] use crate::planner::operator::copy_to_file::CopyToFileOperator; use crate::planner::operator::create_index::CreateIndexOperator; use crate::planner::operator::create_table::CreateTableOperator; @@ -74,6 +77,7 @@ use crate::planner::operator::truncate::TruncateOperator; use crate::planner::operator::union::UnionOperator; use crate::planner::operator::update::UpdateOperator; use crate::planner::operator::values::ValuesOperator; +use crate::planner::{MetaArena, PlanArena}; use crate::types::index::IndexInfo; use kite_sql_serde_macros::ReferenceSerialization; use std::fmt; @@ -119,7 +123,9 @@ pub enum Operator { DropIndex(DropIndexOperator), Truncate(TruncateOperator), // Copy + #[cfg(feature = "copy")] CopyFromFile(CopyFromFileOperator), + #[cfg(feature = "copy")] CopyToFile(CopyToFileOperator), } @@ -182,123 +188,43 @@ pub enum PlanImpl { DropTable, Truncate, Show, + #[cfg(feature = "copy")] CopyFromFile, + #[cfg(feature = "copy")] CopyToFile, Analyze, } impl Operator { - pub fn output_exprs(&self, output_exprs: &mut Vec) -> bool { - match self { - Operator::Dummy => false, - Operator::Aggregate(op) => { - output_exprs.clear(); - output_exprs.extend(op.agg_calls.iter().chain(op.groupby_exprs.iter()).cloned()); - true - } - Operator::ScalarApply(_) - | Operator::MarkApply(_) - | Operator::Filter(_) - | Operator::Join(_) - | Operator::ScalarSubquery(_) => false, - Operator::Project(op) => { - output_exprs.clear(); - output_exprs.extend(op.exprs.iter().cloned()); - true - } - Operator::TableScan(op) => { - output_exprs.clear(); - output_exprs.extend(op.columns.iter().enumerate().map(|(position, column)| { - ScalarExpression::column_expr(column.clone(), position) - })); - true - } - Operator::Sort(_) | Operator::Limit(_) | Operator::TopK(_) => false, - Operator::Values(ValuesOperator { schema_ref, .. }) - | Operator::Union(UnionOperator { - left_schema_ref: schema_ref, - .. - }) - | Operator::SetMembership(SetMembershipOperator { - left_schema_ref: schema_ref, - .. - }) => { - output_exprs.clear(); - output_exprs.extend( - schema_ref - .iter() - .cloned() - .enumerate() - .map(|(position, column)| ScalarExpression::column_expr(column, position)), - ); - true - } - Operator::FunctionScan(op) => { - output_exprs.clear(); - output_exprs.extend( - op.table_function - .inner - .output_schema() - .iter() - .enumerate() - .map(|(position, column)| { - ScalarExpression::column_expr(column.clone(), position) - }), - ); - true - } - Operator::ShowTable - | Operator::ShowView - | Operator::Explain - | Operator::Describe(_) - | Operator::Insert(_) - | Operator::Update(_) - | Operator::Delete(_) - | Operator::Analyze(_) - | Operator::AddColumn(_) - | Operator::ChangeColumn(_) - | Operator::DropColumn(_) - | Operator::CreateTable(_) - | Operator::CreateIndex(_) - | Operator::CreateView(_) - | Operator::DropTable(_) - | Operator::DropView(_) - | Operator::DropIndex(_) - | Operator::Truncate(_) - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) => false, - } - } - - pub fn visit_referenced_columns( + pub fn visit_referenced_columns( &self, - only_column_ref: bool, - f: &mut impl FnMut(&ColumnRef) -> bool, + arena: &mut A, + f: &mut impl FnMut(&mut A, &ColumnRef) -> bool, ) -> bool { match self { Operator::Aggregate(op) => op .agg_calls .iter() .chain(op.groupby_exprs.iter()) - .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), + .all(|expr| expr.visit_referenced_columns(arena, f)), Operator::ScalarApply(_) => true, Operator::MarkApply(op) => op .predicates() .iter() - .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), - Operator::Filter(op) => op.predicate.visit_referenced_columns(only_column_ref, f), + .all(|expr| expr.visit_referenced_columns(arena, f)), + Operator::Filter(op) => op.predicate.visit_referenced_columns(arena, f), Operator::Join(op) => { if let JoinCondition::On { on, filter } = &op.on { for (left_expr, right_expr) in on { - if !left_expr.visit_referenced_columns(only_column_ref, f) - || !right_expr.visit_referenced_columns(only_column_ref, f) + if !left_expr.visit_referenced_columns(arena, f) + || !right_expr.visit_referenced_columns(arena, f) { return false; } } if let Some(filter_expr) = filter { - return filter_expr.visit_referenced_columns(only_column_ref, f); + return filter_expr.visit_referenced_columns(arena, f); } } true @@ -306,25 +232,27 @@ impl Operator { Operator::Project(op) => op .exprs .iter() - .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), + .all(|expr| expr.visit_referenced_columns(arena, f)), Operator::ScalarSubquery(_) => true, - Operator::TableScan(op) => op.columns.iter().all(f), + Operator::TableScan(op) => op.columns.iter().all(|column| f(arena, column)), Operator::FunctionScan(op) => op .table_function .args .iter() - .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), + .all(|expr| expr.visit_referenced_columns(arena, f)), Operator::Sort(op) => op .sort_fields .iter() .map(|field| &field.expr) - .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), + .all(|expr| expr.visit_referenced_columns(arena, f)), Operator::TopK(op) => op .sort_fields .iter() .map(|field| &field.expr) - .all(|expr| expr.visit_referenced_columns(only_column_ref, f)), - Operator::Values(ValuesOperator { schema_ref, .. }) => schema_ref.iter().all(f), + .all(|expr| expr.visit_referenced_columns(arena, f)), + Operator::Values(ValuesOperator { schema_ref, .. }) => { + schema_ref.iter().all(|column| f(arena, column)) + } Operator::Union(UnionOperator { left_schema_ref, _right_schema_ref, @@ -336,9 +264,9 @@ impl Operator { }) => left_schema_ref .iter() .chain(_right_schema_ref.iter()) - .all(f), + .all(|column| f(arena, column)), Operator::Analyze(_) => true, - Operator::Delete(op) => op.primary_keys.iter().all(f), + Operator::Delete(op) => op.primary_keys.iter().all(|column| f(arena, column)), Operator::Dummy | Operator::Limit(_) | Operator::ShowTable @@ -356,19 +284,19 @@ impl Operator { | Operator::DropTable(_) | Operator::DropView(_) | Operator::DropIndex(_) - | Operator::Truncate(_) - | Operator::CopyFromFile(_) - | Operator::CopyToFile(_) => true, + | Operator::Truncate(_) => true, + #[cfg(feature = "copy")] + Operator::CopyFromFile(_) | Operator::CopyToFile(_) => true, } } pub fn any_referenced_column( &self, - only_column_ref: bool, + arena: &mut PlanArena, mut predicate: impl FnMut(&ColumnRef) -> bool, ) -> bool { let mut found = false; - self.visit_referenced_columns(only_column_ref, &mut |column| { + self.visit_referenced_columns(arena, &mut |_, column| { found = predicate(column); !found }); @@ -377,11 +305,11 @@ impl Operator { pub fn all_referenced_columns( &self, - only_column_ref: bool, + arena: &mut PlanArena, mut predicate: impl FnMut(&ColumnRef) -> bool, ) -> bool { let mut all = true; - self.visit_referenced_columns(only_column_ref, &mut |column| { + self.visit_referenced_columns(arena, &mut |_, column| { all = predicate(column); all }); @@ -424,7 +352,9 @@ impl fmt::Display for Operator { Operator::DropView(op) => write!(f, "{op}"), Operator::DropIndex(op) => write!(f, "{op}"), Operator::Truncate(op) => write!(f, "{op}"), + #[cfg(feature = "copy")] Operator::CopyFromFile(op) => write!(f, "{op}"), + #[cfg(feature = "copy")] Operator::CopyToFile(op) => write!(f, "{op}"), Operator::Union(op) => write!(f, "{op}"), Operator::SetMembership(op) => write!(f, "{op}"), @@ -492,7 +422,9 @@ impl fmt::Display for PlanImpl { PlanImpl::DropTable => write!(f, "DropTable"), PlanImpl::Truncate => write!(f, "Truncate"), PlanImpl::Show => write!(f, "Show"), + #[cfg(feature = "copy")] PlanImpl::CopyFromFile => write!(f, "CopyFromFile"), + #[cfg(feature = "copy")] PlanImpl::CopyToFile => write!(f, "CopyToFile"), PlanImpl::Analyze => write!(f, "Analyze"), } diff --git a/src/planner/operator/set_membership.rs b/src/planner/operator/set_membership.rs index 5f34549d..0799eec1 100644 --- a/src/planner/operator/set_membership.rs +++ b/src/planner/operator/set_membership.rs @@ -14,7 +14,7 @@ use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; use std::fmt; @@ -27,7 +27,7 @@ pub enum SetMembershipKind { } impl SetMembershipKind { - fn name(self) -> &'static str { + pub(crate) fn name(self) -> &'static str { match self { Self::Except => "Except", Self::Intersect => "Intersect", @@ -38,16 +38,16 @@ impl SetMembershipKind { #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub struct SetMembershipOperator { pub kind: SetMembershipKind, - pub left_schema_ref: SchemaRef, + pub left_schema_ref: Schema, // mainly use `left_schema` as output and `right_schema` for `column pruning` - pub _right_schema_ref: SchemaRef, + pub _right_schema_ref: Schema, } impl SetMembershipOperator { pub fn build( kind: SetMembershipKind, - left_schema_ref: SchemaRef, - right_schema_ref: SchemaRef, + left_schema_ref: Schema, + right_schema_ref: Schema, left_plan: LogicalPlan, right_plan: LogicalPlan, ) -> LogicalPlan { @@ -67,11 +67,7 @@ impl SetMembershipOperator { impl fmt::Display for SetMembershipOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let schema = self - .left_schema_ref - .iter() - .map(|column| column.name().to_string()) - .join(", "); + let schema = self.left_schema_ref.iter().join(", "); write!(f, "{}: [{schema}]", self.kind.name())?; diff --git a/src/planner/operator/sort.rs b/src/planner/operator/sort.rs index 720f0e7e..5b842a0e 100644 --- a/src/planner/operator/sort.rs +++ b/src/planner/operator/sort.rs @@ -33,6 +33,32 @@ impl SortField { nulls_first, } } + + pub fn asc(mut self) -> Self { + self.asc = true; + self + } + + pub fn desc(mut self) -> Self { + self.asc = false; + self + } + + pub fn nulls_first(mut self) -> Self { + self.nulls_first = true; + self + } + + pub fn nulls_last(mut self) -> Self { + self.nulls_first = false; + self + } +} + +impl From for SortField { + fn from(expr: ScalarExpression) -> Self { + SortField::new(expr, true, false) + } } #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] diff --git a/src/planner/operator/table_scan.rs b/src/planner/operator/table_scan.rs index 637ca8a3..84853d20 100644 --- a/src/planner/operator/table_scan.rs +++ b/src/planner/operator/table_scan.rs @@ -17,7 +17,7 @@ use crate::catalog::{ColumnRef, TableCatalog, TableName}; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; use crate::planner::operator::sort::SortField; -use crate::planner::{Childrens, LogicalPlan}; +use crate::planner::{Childrens, LogicalPlan, PlanArena}; use crate::storage::Bounds; use crate::types::index::IndexInfo; use itertools::Itertools; @@ -43,26 +43,28 @@ impl TableScanOperator { table_name: TableName, table_catalog: &TableCatalog, with_pk: bool, + arena: &PlanArena, ) -> Result { // Fill all Columns in TableCatalog by default - let columns = table_catalog.columns().cloned().collect(); + let columns = table_catalog.columns().copied().collect_vec(); let mut index_infos = Vec::with_capacity(table_catalog.indexes.len()); - for index_meta in table_catalog.indexes.iter() { + for index_ref in table_catalog.indexes.iter().copied() { + let index_meta = arena.index(index_ref); let mut sort_fields = Vec::with_capacity(index_meta.column_ids.len()); for col_id in &index_meta.column_ids { - let column = table_catalog.get_column_by_id(col_id).ok_or_else(|| { + let column_ref = table_catalog.get_column_by_id(col_id).ok_or_else(|| { DatabaseError::column_not_found(format!("index column id: {col_id} not found")) })?; sort_fields.push(SortField { - expr: ScalarExpression::column_expr(column.clone(), sort_fields.len()), + expr: ScalarExpression::column_expr(column_ref, sort_fields.len()), asc: true, nulls_first: false, }) } index_infos.push(IndexInfo { - meta: index_meta.clone(), + meta: index_ref, sort_option: SortOption::OrderBy { fields: sort_fields, ignore_prefix_len: 0, @@ -90,11 +92,7 @@ impl TableScanOperator { impl fmt::Display for TableScanOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let projection_columns = self - .columns - .iter() - .map(|column| column.name().to_string()) - .join(", "); + let projection_columns = self.columns.iter().join(", "); let (offset, limit) = self.limit; write!( diff --git a/src/planner/operator/union.rs b/src/planner/operator/union.rs index a196e796..e3f7838f 100644 --- a/src/planner/operator/union.rs +++ b/src/planner/operator/union.rs @@ -14,7 +14,7 @@ use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; use std::fmt; @@ -22,15 +22,15 @@ use std::fmt::Formatter; #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub struct UnionOperator { - pub left_schema_ref: SchemaRef, + pub left_schema_ref: Schema, // mainly use `left_schema` as output and `right_schema` for `column pruning` - pub _right_schema_ref: SchemaRef, + pub _right_schema_ref: Schema, } impl UnionOperator { pub fn build( - left_schema_ref: SchemaRef, - right_schema_ref: SchemaRef, + left_schema_ref: Schema, + right_schema_ref: Schema, left_plan: LogicalPlan, right_plan: LogicalPlan, ) -> LogicalPlan { @@ -49,11 +49,7 @@ impl UnionOperator { impl fmt::Display for UnionOperator { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let schema = self - .left_schema_ref - .iter() - .map(|column| column.name().to_string()) - .join(", "); + let schema = self.left_schema_ref.iter().join(", "); write!(f, "Union: [{schema}]")?; diff --git a/src/planner/operator/update.rs b/src/planner/operator/update.rs index 2c767299..91b735b4 100644 --- a/src/planner/operator/update.rs +++ b/src/planner/operator/update.rs @@ -30,7 +30,7 @@ impl fmt::Display for UpdateOperator { let values = self .value_exprs .iter() - .map(|(column, expr)| format!("{} -> {}", column.full_name(), expr)) + .map(|(column, expr)| format!("{column} -> {expr}")) .join(", "); write!(f, "Update {} set {}", self.table_name, values)?; diff --git a/src/planner/operator/values.rs b/src/planner/operator/values.rs index 624554b6..9140713d 100644 --- a/src/planner/operator/values.rs +++ b/src/planner/operator/values.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::types::tuple::SchemaRef; +use crate::types::tuple::Schema; use crate::types::value::DataValue; use itertools::Itertools; use kite_sql_serde_macros::ReferenceSerialization; @@ -22,7 +22,7 @@ use std::fmt::Formatter; #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub struct ValuesOperator { pub rows: Vec>, - pub schema_ref: SchemaRef, + pub schema_ref: Schema, } impl fmt::Display for ValuesOperator { diff --git a/src/python.rs b/src/python.rs index 4c959bf6..12dd9558 100644 --- a/src/python.rs +++ b/src/python.rs @@ -21,7 +21,7 @@ use crate::storage::lmdb::LmdbStorage; use crate::storage::memory::MemoryStorage; #[cfg(feature = "rocksdb")] use crate::storage::rocksdb::RocksStorage; -use crate::types::tuple::{SchemaRef, Tuple}; +use crate::types::tuple::{SchemaView, Tuple}; use crate::types::value::DataValue; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; @@ -82,7 +82,7 @@ fn tuple_to_python_row(py: Python<'_>, tuple: &Tuple) -> PyResult { Ok(row.into_any().unbind()) } -fn schema_to_python(py: Python<'_>, schema: &SchemaRef) -> PyResult> { +fn schema_to_python(py: Python<'_>, schema: &SchemaView<'_, '_>) -> PyResult> { schema .iter() .map(|col| { @@ -131,6 +131,26 @@ impl PythonDatabaseInner { } } } + + fn ddl(&mut self, sql: &str) -> Result<(), DatabaseError> { + match self { + #[cfg(feature = "lmdb")] + PythonDatabaseInner::Lmdb(db) => db.ddl(sql), + PythonDatabaseInner::Memory(db) => db.ddl(sql), + #[cfg(feature = "rocksdb")] + PythonDatabaseInner::Rocks(db) => db.ddl(sql), + } + } + + fn analyze(&mut self, table_name: &str) -> Result<(), DatabaseError> { + match self { + #[cfg(feature = "lmdb")] + PythonDatabaseInner::Lmdb(db) => db.analyze(table_name), + PythonDatabaseInner::Memory(db) => db.analyze(table_name), + #[cfg(feature = "rocksdb")] + PythonDatabaseInner::Rocks(db) => db.analyze(table_name), + } + } } enum PythonResultIterInner { @@ -152,13 +172,13 @@ impl PythonResultIterInner { } } - fn schema(&self) -> &SchemaRef { + fn schema(&self, f: impl FnOnce(&SchemaView<'_, '_>) -> R) -> R { match self { #[cfg(feature = "lmdb")] - PythonResultIterInner::Lmdb(iter) => iter.schema(), - PythonResultIterInner::Memory(iter) => iter.schema(), + PythonResultIterInner::Lmdb(iter) => iter.schema(f), + PythonResultIterInner::Memory(iter) => iter.schema(f), #[cfg(feature = "rocksdb")] - PythonResultIterInner::Rocks(iter) => iter.schema(), + PythonResultIterInner::Rocks(iter) => iter.schema(f), } } @@ -198,11 +218,12 @@ impl PythonDatabase { .map_err(to_py_err)?, ), other => { - let mut expected = Vec::new(); - #[cfg(feature = "rocksdb")] - expected.push("rocksdb"); - #[cfg(feature = "lmdb")] - expected.push("lmdb"); + let expected = [ + #[cfg(feature = "rocksdb")] + "rocksdb", + #[cfg(feature = "lmdb")] + "lmdb", + ]; return Err(PyValueError::new_err(format!( "unsupported backend '{other}', expected {}", expected.join(" or ") @@ -235,6 +256,14 @@ impl PythonDatabase { Ok(()) } + + pub fn ddl(&mut self, sql: &str) -> PyResult<()> { + self.inner.ddl(sql).map_err(to_py_err) + } + + pub fn analyze(&mut self, table_name: &str) -> PyResult<()> { + self.inner.analyze(table_name).map_err(to_py_err) + } } #[pyclass(name = "ResultIter", unsendable)] @@ -269,7 +298,7 @@ impl PythonResultIter { pub fn schema(&self, py: Python<'_>) -> PyResult> { let iter = self.inner_ref()?; - schema_to_python(py, iter.schema()) + iter.schema(|schema| schema_to_python(py, schema)) } pub fn rows(&mut self, py: Python<'_>) -> PyResult> { @@ -377,8 +406,8 @@ mod tests { c_str!( r#" db = kite_sql.Database.in_memory() if backend == "memory" else kite_sql.Database(db_path, backend) -db.execute("drop table if exists my_struct") -db.execute("create table my_struct (c1 int primary key, c2 int)") +db.ddl("drop table if exists my_struct") +db.ddl("create table my_struct (c1 int primary key, c2 int)") db.execute("insert into my_struct values(0, 0), (1, 1)") iter_obj = db.run("select * from my_struct") @@ -405,7 +434,7 @@ while row is not None: stream.finish() assert streamed == [[0, 0], [1, 11]] -db.execute("drop table my_struct") +db.ddl("drop table my_struct") "# ), )?; @@ -423,8 +452,8 @@ db.execute("drop table my_struct") c_str!( r#" db = kite_sql.Database.in_memory() if backend == "memory" else kite_sql.Database(db_path, backend) -db.execute("drop table if exists t1") -db.execute("create table t1(id int primary key, c1 int, c2 int)") +db.ddl("drop table if exists t1") +db.ddl("create table t1(id int primary key, c1 int, c2 int)") for i in range(2000): id_v = i * 3 @@ -432,10 +461,10 @@ for i in range(2000): c2_v = id_v + 2 db.execute(f"insert into t1 values({id_v}, {c1_v}, {c2_v})") -db.execute("create unique index u_c1_index on t1 (c1)") -db.execute("create index c2_index on t1 (c2)") -db.execute("create index p_index on t1 (c1, c2)") -db.execute("analyze table t1") +db.ddl("create unique index u_c1_index on t1 (c1)") +db.ddl("create index c2_index on t1 (c2)") +db.ddl("create index p_index on t1 (c1, c2)") +db.analyze("t1") def row_vals(row): ints = row["values"] @@ -473,7 +502,43 @@ db.execute("delete from t1 where c1 = 7") after_delete = db.run("select * from t1 where c2 = 123456").rows() assert len(after_delete) == 0 -db.execute("drop table t1") +db.ddl("drop table t1") +"# + ), + )?; + Ok(()) + }) + } + + #[test] + fn test_python_mutations_use_explicit_api() -> PyResult<()> { + Python::with_gil(|py| { + let module = register_module(py)?; + run_script_on_all_backends( + py, + &module, + c_str!( + r#" +db = kite_sql.Database.in_memory() if backend == "memory" else kite_sql.Database(db_path, backend) + +try: + db.execute("create table explicit_api(id int primary key)") + raise AssertionError("expected execute to reject DDL") +except RuntimeError as exc: + assert "Database::ddl" in str(exc) + +db.ddl("create table explicit_api(id int primary key)") +for i in range(200): + db.execute(f"insert into explicit_api values ({i})") + +try: + db.execute("analyze table explicit_api") + raise AssertionError("expected execute to reject ANALYZE") +except RuntimeError as exc: + assert "Database::analyze" in str(exc) + +db.analyze("explicit_api") +db.ddl("drop table explicit_api") "# ), )?; diff --git a/src/serdes/boolean.rs b/src/serdes/boolean.rs index fcbb3b8f..d7447a27 100644 --- a/src/serdes/boolean.rs +++ b/src/serdes/boolean.rs @@ -18,20 +18,22 @@ use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for bool { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - if *self { 1u8 } else { 0u8 }.encode(writer, is_direct, reference_tables) + if *self { 1u8 } else { 0u8 }.encode(writer, is_direct, reference_tables, arena) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - Ok(u8::decode(reader, drive, reference_tables)? == 1u8) + Ok(u8::decode(reader, drive, reference_tables, arena)? == 1u8) } } diff --git a/src/serdes/bound.rs b/src/serdes/bound.rs index 825f2b3d..2faa8513 100644 --- a/src/serdes/bound.rs +++ b/src/serdes/bound.rs @@ -22,22 +22,23 @@ impl ReferenceSerialization for Bound where V: ReferenceSerialization, { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { match self { Bound::Included(v) => { writer.write_all(&[0])?; - v.encode(writer, is_direct, reference_tables)?; + v.encode(writer, is_direct, reference_tables, arena)?; } Bound::Excluded(v) => { writer.write_all(&[1])?; - v.encode(writer, is_direct, reference_tables)?; + v.encode(writer, is_direct, reference_tables, arena)?; } Bound::Unbounded => { writer.write_all(&[2])?; @@ -47,17 +48,18 @@ where Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { let mut type_bytes = [0u8; 1]; reader.read_exact(&mut type_bytes)?; Ok(match type_bytes[0] { - 0 => Bound::Included(V::decode(reader, drive, reference_tables)?), - 1 => Bound::Excluded(V::decode(reader, drive, reference_tables)?), + 0 => Bound::Included(V::decode(reader, drive, reference_tables, arena)?), + 1 => Bound::Excluded(V::decode(reader, drive, reference_tables, arena)?), 2 => Bound::Unbounded, _ => unreachable!(), }) diff --git a/src/serdes/btree_map.rs b/src/serdes/btree_map.rs index 7689ea93..149114e2 100644 --- a/src/serdes/btree_map.rs +++ b/src/serdes/btree_map.rs @@ -23,30 +23,34 @@ where K: ReferenceSerialization + Ord, V: ReferenceSerialization, { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.len().encode(writer, is_direct, reference_tables)?; + self.len() + .encode(writer, is_direct, reference_tables, arena)?; for (key, value) in self.iter() { - key.encode(writer, is_direct, reference_tables)?; - value.encode(writer, is_direct, reference_tables)?; + key.encode(writer, is_direct, reference_tables, arena)?; + value.encode(writer, is_direct, reference_tables, arena)?; } Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let len = ::decode(reader, drive, reference_tables)?; + let len = + ::decode(reader, drive, reference_tables, arena)?; let mut btree_map = BTreeMap::new(); for _ in 0..len { - let key = K::decode(reader, drive, reference_tables)?; - let value = V::decode(reader, drive, reference_tables)?; + let key = K::decode(reader, drive, reference_tables, arena)?; + let value = V::decode(reader, drive, reference_tables, arena)?; btree_map.insert(key, value); } Ok(btree_map) diff --git a/src/serdes/char.rs b/src/serdes/char.rs index c6dcc103..05339336 100644 --- a/src/serdes/char.rs +++ b/src/serdes/char.rs @@ -18,11 +18,12 @@ use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for char { - fn encode( + fn encode( &self, writer: &mut W, _: bool, _: &mut ReferenceTables, + _: &A, ) -> Result<(), DatabaseError> { let mut buf = [0u8; 2]; self.encode_utf8(&mut buf); @@ -30,10 +31,11 @@ impl ReferenceSerialization for char { Ok(writer.write_all(&buf)?) } - fn decode( + fn decode( reader: &mut R, _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, + _: &mut A, ) -> Result { let mut buf = [0u8; 2]; reader.read_exact(&mut buf)?; diff --git a/src/serdes/char_length_units.rs b/src/serdes/char_length_units.rs index b0b4036a..d9f7725a 100644 --- a/src/serdes/char_length_units.rs +++ b/src/serdes/char_length_units.rs @@ -19,25 +19,27 @@ use crate::types::CharLengthUnits; use std::io::{Read, Write}; impl ReferenceSerialization for CharLengthUnits { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { match self { CharLengthUnits::Characters => 0u8, CharLengthUnits::Octets => 1u8, } - .encode(writer, is_direct, reference_tables)?; + .encode(writer, is_direct, reference_tables, arena)?; Ok(()) } - fn decode( + fn decode( reader: &mut R, _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, + _: &mut A, ) -> Result { let mut one_byte = [0u8; 1]; reader.read_exact(&mut one_byte)?; @@ -62,18 +64,29 @@ pub(crate) mod test { fn test_serialization() -> Result<(), DatabaseError> { let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); - CharLengthUnits::Characters.encode(&mut cursor, false, &mut reference_tables)?; + CharLengthUnits::Characters.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( - CharLengthUnits::decode::(&mut cursor, None, &reference_tables)?, + CharLengthUnits::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena, + )?, CharLengthUnits::Characters ); cursor.seek(SeekFrom::Start(0))?; - CharLengthUnits::Octets.encode(&mut cursor, false, &mut reference_tables)?; + CharLengthUnits::Octets.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( - CharLengthUnits::decode::(&mut cursor, None, &reference_tables)?, + CharLengthUnits::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena, + )?, CharLengthUnits::Octets ); diff --git a/src/serdes/column.rs b/src/serdes/column.rs index d2d66685..71092fa4 100644 --- a/src/serdes/column.rs +++ b/src/serdes/column.rs @@ -12,91 +12,62 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, ColumnSummary}; +use crate::catalog::{ColumnCatalog, ColumnRef, ColumnRelation}; use crate::errors::DatabaseError; +use crate::planner::MetaArena; use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; use crate::storage::Transaction; use crate::types::ColumnId; use std::io::{Read, Write}; impl ReferenceSerialization for ColumnRef { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.summary().encode(writer, is_direct, reference_tables)?; - self.in_join() - .then(|| self.nullable()) - .encode(writer, is_direct, reference_tables)?; - - if is_direct - || !matches!( - self.summary().relation, - ColumnRelation::Table { is_temp: false, .. } - ) - { - self.nullable() - .encode(writer, is_direct, reference_tables)?; - self.desc().encode(writer, is_direct, reference_tables)?; - } - - Ok(()) + arena + .column(*self) + .encode(writer, is_direct, reference_tables, arena) } - fn decode( + fn decode( reader: &mut R, drive: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let summary = ColumnSummary::decode(reader, drive, reference_tables)?; - let nullable_for_join = Option::::decode(reader, drive, reference_tables)?; + let column = ColumnCatalog::decode(reader, drive, reference_tables, arena)?; - if let ( - ColumnRelation::Table { - column_id, - table_name, - is_temp: false, - }, - Some((transaction, table_cache)), - ) = ( - &summary.relation, - drive.and_then(ReferenceDecodeContext::drive), - ) { - let table = transaction - .table(table_cache, table_name.clone())? - .ok_or(DatabaseError::TableNotFound)?; - let column = table - .get_column_by_id(column_id) - .ok_or(DatabaseError::invalid_column(format!( - "column id: {column_id} not found" - )))?; - Ok(nullable_for_join - .and_then(|nullable| column.nullable_for_join(nullable)) - .unwrap_or_else(|| column.clone())) - } else { - let mut nullable = bool::decode(reader, drive, reference_tables)?; - let desc = ColumnDesc::decode(reader, drive, reference_tables)?; - let mut in_join = false; - if let Some(nullable_for_join) = nullable_for_join { - in_join = true; - nullable = nullable_for_join; + if let ColumnRelation::Table { + column_id, + table_name, + is_temp: false, + } = &column.summary().relation + { + if let Some((_, table_cache)) = drive.and_then(ReferenceDecodeContext::drive) { + if let Some(column_ref) = table_cache + .get(table_name) + .and_then(|table| table.get_column_by_id(column_id)) + { + return Ok(column_ref); + } } - - Ok(Self::from(ColumnCatalog::direct_new( - summary, nullable, desc, in_join, - ))) } + + Ok(arena.alloc_column(column)) } } impl ReferenceSerialization for ColumnRelation { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { match self { ColumnRelation::None => { @@ -108,13 +79,14 @@ impl ReferenceSerialization for ColumnRelation { is_temp, } => { writer.write_all(&[1])?; - column_id.encode(writer, is_direct, reference_tables)?; - is_temp.encode(writer, is_direct, reference_tables)?; + column_id.encode(writer, is_direct, reference_tables, arena)?; + is_temp.encode(writer, is_direct, reference_tables, arena)?; reference_tables.push_or_replace(table_name).encode( writer, is_direct, reference_tables, + arena, )?; } } @@ -122,10 +94,11 @@ impl ReferenceSerialization for ColumnRelation { Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { let mut type_bytes = [0u8; 1]; reader.read_exact(&mut type_bytes)?; @@ -133,13 +106,14 @@ impl ReferenceSerialization for ColumnRelation { Ok(match type_bytes[0] { 0 => ColumnRelation::None, 1 => { - let column_id = ColumnId::decode(reader, drive, reference_tables)?; - let is_temp = bool::decode(reader, drive, reference_tables)?; + let column_id = ColumnId::decode(reader, drive, reference_tables, arena)?; + let is_temp = bool::decode(reader, drive, reference_tables, arena)?; let table_name = reference_tables .get(::decode( reader, drive, reference_tables, + arena, )?) .clone(); @@ -162,15 +136,14 @@ pub(crate) mod test { use crate::db::test::build_table; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; + use crate::planner::{PlanArena, TableArenaCell}; use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; - use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; - use crate::storage::{StatisticsMetaCache, Storage, Transaction}; + use crate::storage::rocksdb::RocksStorage; + use crate::storage::rocksdb::RocksTransaction; + use crate::storage::{Storage, Transaction}; use crate::types::value::DataValue; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; - use std::hash::RandomState; use std::io::{Cursor, Seek, SeekFrom}; - use std::sync::Arc; use tempfile::TempDir; use ulid::Ulid; @@ -179,68 +152,66 @@ pub(crate) mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let meta_cache = StatisticsMetaCache::new(4, 1, RandomState::new())?; + let mut table_cache = crate::storage::TableCache::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); let table_name: TableName = "t1".to_string().into(); - build_table(&table_cache, &mut transaction)?; + build_table(&mut table_cache, &mut transaction, &mut plan_arena)?; + let mut plan_arena = PlanArena::new(&table_arena); let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); - let c3_column_id = { + let ref_column = { let table = transaction .table(&table_cache, "t1".to_string().into())? .unwrap(); - *table.get_column_id_by_name("c3").unwrap() + table.get_column_by_name("c3").unwrap() }; { - let ref_column = ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c3".to_string(), - relation: ColumnRelation::Table { - column_id: c3_column_id, - table_name: table_name.clone(), - is_temp: false, - }, - }, - false, - ColumnDesc::new(LogicalType::Integer, None, false, None)?, - false, - )); - - ref_column.encode(&mut cursor, false, &mut reference_tables)?; + ref_column.encode(&mut cursor, false, &mut reference_tables, &plan_arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( { let context = ReferenceDecodeContext::new(Some((&transaction, &table_cache))); - ColumnRef::decode::>>( + ColumnRef::decode::>, _>( &mut cursor, Some(&context), &reference_tables, + &mut plan_arena, )? }, ref_column ); cursor.seek(SeekFrom::Start(0))?; - transaction.drop_column(&table_cache, &meta_cache, &table_name, "c3")?; + let mut table_codec = crate::storage::table_codec::TableCodec::default(); + let table = + transaction.drop_column(&mut table_codec, &mut plan_arena, &table_name, "c3")?; + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + plan_arena = PlanArena::new(&table_arena); let context = ReferenceDecodeContext::new(Some((&transaction, &table_cache))); - assert!(ColumnRef::decode::>>( - &mut cursor, - Some(&context), - &reference_tables - ) - .is_err()); + assert_eq!( + ColumnRef::decode::>, _>( + &mut cursor, + Some(&context), + &reference_tables, + &mut plan_arena, + )?, + ref_column + ); + let table = transaction + .table(&table_cache, table_name.clone())? + .expect("table should still exist after dropping one column"); + assert!(table.get_column_id_by_name("c3").is_none()); cursor.seek(SeekFrom::Start(0))?; } { - let not_ref_column = ColumnRef::from(ColumnCatalog::direct_new( - ColumnSummary { - name: "c3".to_string(), - relation: ColumnRelation::None, - }, + let not_ref_column = plan_arena.alloc_column(ColumnCatalog::new( + "c3".to_string(), false, ColumnDesc::new( LogicalType::Integer, @@ -248,18 +219,19 @@ pub(crate) mod test { false, Some(ScalarExpression::Constant(DataValue::UInt64(42))), )?, - false, )); - not_ref_column.encode(&mut cursor, false, &mut reference_tables)?; + not_ref_column.encode(&mut cursor, false, &mut reference_tables, &plan_arena)?; cursor.seek(SeekFrom::Start(0))?; + let decoded = ColumnRef::decode::>, _>( + &mut cursor, + None, + &reference_tables, + &mut plan_arena, + )?; assert_eq!( - ColumnRef::decode::>>( - &mut cursor, - None, - &reference_tables - )?, - not_ref_column + plan_arena.column(decoded), + plan_arena.column(not_ref_column) ); } @@ -270,6 +242,7 @@ pub(crate) mod test { fn test_column_summary_serialization() -> Result<(), DatabaseError> { let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); let summary = ColumnSummary { name: "c1".to_string(), relation: ColumnRelation::Table { @@ -278,14 +251,15 @@ pub(crate) mod test { is_temp: false, }, }; - summary.encode(&mut cursor, false, &mut reference_tables)?; + summary.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( - ColumnSummary::decode::>>( + ColumnSummary::decode::>, _>( &mut cursor, None, - &reference_tables + &reference_tables, + &mut arena, )?, summary ); @@ -297,14 +271,16 @@ pub(crate) mod test { fn test_column_relation_serialization() -> Result<(), DatabaseError> { let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); let none_relation = ColumnRelation::None; - none_relation.encode(&mut cursor, false, &mut reference_tables)?; + none_relation.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; - let decode_relation = ColumnRelation::decode::>>( + let decode_relation = ColumnRelation::decode::>, _>( &mut cursor, None, &reference_tables, + &mut arena, )?; assert_eq!(none_relation, decode_relation); cursor.seek(SeekFrom::Start(0))?; @@ -313,13 +289,14 @@ pub(crate) mod test { table_name: "t1".to_string().into(), is_temp: false, }; - table_relation.encode(&mut cursor, false, &mut reference_tables)?; + table_relation.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; - let decode_relation = ColumnRelation::decode::>>( + let decode_relation = ColumnRelation::decode::>, _>( &mut cursor, None, &reference_tables, + &mut arena, )?; assert_eq!(table_relation, decode_relation); @@ -330,19 +307,21 @@ pub(crate) mod test { fn test_column_desc_serialization() -> Result<(), DatabaseError> { let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); let desc = ColumnDesc::new( LogicalType::Integer, None, false, Some(ScalarExpression::Constant(DataValue::UInt64(42))), )?; - desc.encode(&mut cursor, false, &mut reference_tables)?; + desc.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; - let decode_desc = ColumnDesc::decode::>>( + let decode_desc = ColumnDesc::decode::>, _>( &mut cursor, None, &reference_tables, + &mut arena, )?; assert_eq!(desc, decode_desc); diff --git a/src/serdes/data_value.rs b/src/serdes/data_value.rs index 1ea45dba..e54641d8 100644 --- a/src/serdes/data_value.rs +++ b/src/serdes/data_value.rs @@ -12,60 +12,456 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::implement_serialization_by_bincode; +use crate::errors::DatabaseError; +use crate::serdes::{ReferenceSerialization, ReferenceTables}; +use crate::storage::Transaction; use crate::types::value::DataValue; +use crate::types::value::Utf8Type; +use crate::types::CharLengthUnits; +use ordered_float::OrderedFloat; +#[cfg(feature = "decimal")] +use rust_decimal::Decimal; +use std::io::{Read, Write}; -implement_serialization_by_bincode!(DataValue); +const TAG_NULL: u8 = 0; +const TAG_BOOLEAN: u8 = 1; +const TAG_FLOAT32: u8 = 2; +const TAG_FLOAT64: u8 = 3; +const TAG_INT8: u8 = 4; +const TAG_INT16: u8 = 5; +const TAG_INT32: u8 = 6; +const TAG_INT64: u8 = 7; +const TAG_UINT8: u8 = 8; +const TAG_UINT16: u8 = 9; +const TAG_UINT32: u8 = 10; +const TAG_UINT64: u8 = 11; +const TAG_UTF8: u8 = 12; +const TAG_DATE32: u8 = 13; +const TAG_DATE64: u8 = 14; +const TAG_TIME32: u8 = 15; +const TAG_TIME64: u8 = 16; +const TAG_DECIMAL: u8 = 17; +const TAG_TUPLE: u8 = 18; + +impl ReferenceSerialization for Utf8Type { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + arena: &A, + ) -> Result<(), DatabaseError> { + match self { + Utf8Type::Variable(len) => { + 0u8.encode(writer, is_direct, reference_tables, arena)?; + len.encode(writer, is_direct, reference_tables, arena) + } + Utf8Type::Fixed(len) => { + 1u8.encode(writer, is_direct, reference_tables, arena)?; + len.encode(writer, is_direct, reference_tables, arena) + } + } + } + + fn decode( + reader: &mut R, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + arena: &mut A, + ) -> Result { + match u8::decode(reader, drive, reference_tables, arena)? { + 0 => Ok(Utf8Type::Variable(Option::::decode( + reader, + drive, + reference_tables, + arena, + )?)), + 1 => Ok(Utf8Type::Fixed(u32::decode( + reader, + drive, + reference_tables, + arena, + )?)), + tag => Err(DatabaseError::InvalidValue(format!( + "invalid utf8 type tag: {tag}" + ))), + } + } +} + +impl DataValue { + pub(crate) fn encode_reference_value( + &self, + writer: &mut W, + ) -> Result<(), DatabaseError> { + match self { + DataValue::Null => write_u8(writer, TAG_NULL), + DataValue::Boolean(value) => { + write_u8(writer, TAG_BOOLEAN)?; + write_bool(writer, *value) + } + DataValue::Float32(value) => { + write_u8(writer, TAG_FLOAT32)?; + write_f32(writer, value.0) + } + DataValue::Float64(value) => { + write_u8(writer, TAG_FLOAT64)?; + write_f64(writer, value.0) + } + DataValue::Int8(value) => { + write_u8(writer, TAG_INT8)?; + write_i8(writer, *value) + } + DataValue::Int16(value) => { + write_u8(writer, TAG_INT16)?; + write_i16(writer, *value) + } + DataValue::Int32(value) => { + write_u8(writer, TAG_INT32)?; + write_i32(writer, *value) + } + DataValue::Int64(value) => { + write_u8(writer, TAG_INT64)?; + write_i64(writer, *value) + } + DataValue::UInt8(value) => { + write_u8(writer, TAG_UINT8)?; + write_u8(writer, *value) + } + DataValue::UInt16(value) => { + write_u8(writer, TAG_UINT16)?; + write_u16(writer, *value) + } + DataValue::UInt32(value) => { + write_u8(writer, TAG_UINT32)?; + write_u32(writer, *value) + } + DataValue::UInt64(value) => { + write_u8(writer, TAG_UINT64)?; + write_u64(writer, *value) + } + DataValue::Utf8 { value, ty, unit } => { + write_u8(writer, TAG_UTF8)?; + write_string(writer, value)?; + write_utf8_type(writer, ty)?; + write_char_length_units(writer, *unit) + } + DataValue::Date32(value) => { + write_u8(writer, TAG_DATE32)?; + write_i32(writer, *value) + } + DataValue::Date64(value) => { + write_u8(writer, TAG_DATE64)?; + write_i64(writer, *value) + } + DataValue::Time32(value, precision) => { + write_u8(writer, TAG_TIME32)?; + write_u32(writer, *value)?; + write_u64(writer, *precision) + } + DataValue::Time64(value, precision, with_tz) => { + write_u8(writer, TAG_TIME64)?; + write_i64(writer, *value)?; + write_u64(writer, *precision)?; + write_bool(writer, *with_tz) + } + #[cfg(feature = "decimal")] + DataValue::Decimal(value) => { + write_u8(writer, TAG_DECIMAL)?; + writer.write_all(&value.serialize())?; + Ok(()) + } + DataValue::Tuple(values, is_upper) => { + write_u8(writer, TAG_TUPLE)?; + write_len(writer, values.len())?; + for value in values { + value.encode_reference_value(writer)?; + } + write_bool(writer, *is_upper) + } + } + } + + pub(crate) fn decode_reference_value(reader: &mut R) -> Result { + match read_u8(reader)? { + TAG_NULL => Ok(DataValue::Null), + TAG_BOOLEAN => Ok(DataValue::Boolean(read_bool(reader)?)), + TAG_FLOAT32 => Ok(DataValue::Float32(OrderedFloat(read_f32(reader)?))), + TAG_FLOAT64 => Ok(DataValue::Float64(OrderedFloat(read_f64(reader)?))), + TAG_INT8 => Ok(DataValue::Int8(read_i8(reader)?)), + TAG_INT16 => Ok(DataValue::Int16(read_i16(reader)?)), + TAG_INT32 => Ok(DataValue::Int32(read_i32(reader)?)), + TAG_INT64 => Ok(DataValue::Int64(read_i64(reader)?)), + TAG_UINT8 => Ok(DataValue::UInt8(read_u8(reader)?)), + TAG_UINT16 => Ok(DataValue::UInt16(read_u16(reader)?)), + TAG_UINT32 => Ok(DataValue::UInt32(read_u32(reader)?)), + TAG_UINT64 => Ok(DataValue::UInt64(read_u64(reader)?)), + TAG_UTF8 => Ok(DataValue::Utf8 { + value: read_string(reader)?, + ty: read_utf8_type(reader)?, + unit: read_char_length_units(reader)?, + }), + TAG_DATE32 => Ok(DataValue::Date32(read_i32(reader)?)), + TAG_DATE64 => Ok(DataValue::Date64(read_i64(reader)?)), + TAG_TIME32 => Ok(DataValue::Time32(read_u32(reader)?, read_u64(reader)?)), + TAG_TIME64 => Ok(DataValue::Time64( + read_i64(reader)?, + read_u64(reader)?, + read_bool(reader)?, + )), + #[cfg(feature = "decimal")] + TAG_DECIMAL => { + let mut bytes = [0; 16]; + reader.read_exact(&mut bytes)?; + Ok(DataValue::Decimal(Decimal::deserialize(bytes))) + } + #[cfg(not(feature = "decimal"))] + TAG_DECIMAL => { + let mut bytes = [0; 16]; + reader.read_exact(&mut bytes)?; + Err(DatabaseError::UnsupportedStmt( + "DECIMAL requires the `decimal` feature".to_string(), + )) + } + TAG_TUPLE => { + let len = read_len(reader)?; + let mut values = Vec::with_capacity(len); + for _ in 0..len { + values.push(DataValue::decode_reference_value(reader)?); + } + Ok(DataValue::Tuple(values, read_bool(reader)?)) + } + tag => Err(DatabaseError::InvalidValue(format!( + "invalid data value tag: {tag}" + ))), + } + } +} + +impl ReferenceSerialization for DataValue { + fn encode( + &self, + writer: &mut W, + _: bool, + _: &mut ReferenceTables, + _: &A, + ) -> Result<(), DatabaseError> { + self.encode_reference_value(writer) + } + + fn decode( + reader: &mut R, + _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + _: &ReferenceTables, + _: &mut A, + ) -> Result { + DataValue::decode_reference_value(reader) + } +} + +fn write_u8(writer: &mut W, value: u8) -> Result<(), DatabaseError> { + writer.write_all(&[value])?; + Ok(()) +} + +fn read_u8(reader: &mut R) -> Result { + let mut bytes = [0]; + reader.read_exact(&mut bytes)?; + Ok(bytes[0]) +} + +fn write_bool(writer: &mut W, value: bool) -> Result<(), DatabaseError> { + write_u8(writer, u8::from(value)) +} + +fn read_bool(reader: &mut R) -> Result { + match read_u8(reader)? { + 0 => Ok(false), + 1 => Ok(true), + value => Err(DatabaseError::InvalidValue(format!( + "invalid bool value: {value}" + ))), + } +} + +fn write_i8(writer: &mut W, value: i8) -> Result<(), DatabaseError> { + writer.write_all(&value.to_le_bytes())?; + Ok(()) +} + +fn read_i8(reader: &mut R) -> Result { + Ok(i8::from_le_bytes([read_u8(reader)?])) +} + +macro_rules! implement_raw_num { + ($write_name:ident, $read_name:ident, $ty:ty) => { + fn $write_name(writer: &mut W, value: $ty) -> Result<(), DatabaseError> { + writer.write_all(&value.to_le_bytes())?; + Ok(()) + } + + fn $read_name(reader: &mut R) -> Result<$ty, DatabaseError> { + let mut bytes = [0; std::mem::size_of::<$ty>()]; + reader.read_exact(&mut bytes)?; + Ok(<$ty>::from_le_bytes(bytes)) + } + }; +} + +implement_raw_num!(write_i16, read_i16, i16); +implement_raw_num!(write_i32, read_i32, i32); +implement_raw_num!(write_i64, read_i64, i64); +implement_raw_num!(write_u16, read_u16, u16); +implement_raw_num!(write_u32, read_u32, u32); +implement_raw_num!(write_u64, read_u64, u64); +implement_raw_num!(write_f32, read_f32, f32); +implement_raw_num!(write_f64, read_f64, f64); + +fn write_len(writer: &mut W, len: usize) -> Result<(), DatabaseError> { + write_u32(writer, len.try_into()?) +} + +fn read_len(reader: &mut R) -> Result { + Ok(read_u32(reader)? as usize) +} + +fn write_string(writer: &mut W, value: &str) -> Result<(), DatabaseError> { + write_len(writer, value.len())?; + writer.write_all(value.as_bytes())?; + Ok(()) +} + +fn read_string(reader: &mut R) -> Result { + let len = read_len(reader)?; + let mut bytes = vec![0; len]; + reader.read_exact(&mut bytes)?; + Ok(String::from_utf8(bytes)?) +} + +fn write_utf8_type(writer: &mut W, value: &Utf8Type) -> Result<(), DatabaseError> { + match value { + Utf8Type::Variable(len) => { + write_u8(writer, 0)?; + match len { + None => write_u8(writer, 0), + Some(len) => { + write_u8(writer, 1)?; + write_u32(writer, *len) + } + } + } + Utf8Type::Fixed(len) => { + write_u8(writer, 1)?; + write_u32(writer, *len) + } + } +} + +fn read_utf8_type(reader: &mut R) -> Result { + match read_u8(reader)? { + 0 => match read_u8(reader)? { + 0 => Ok(Utf8Type::Variable(None)), + 1 => Ok(Utf8Type::Variable(Some(read_u32(reader)?))), + tag => Err(DatabaseError::InvalidValue(format!( + "invalid option tag: {tag}" + ))), + }, + 1 => Ok(Utf8Type::Fixed(read_u32(reader)?)), + tag => Err(DatabaseError::InvalidValue(format!( + "invalid utf8 type tag: {tag}" + ))), + } +} + +fn write_char_length_units( + writer: &mut W, + value: CharLengthUnits, +) -> Result<(), DatabaseError> { + write_u8( + writer, + match value { + CharLengthUnits::Characters => 0, + CharLengthUnits::Octets => 1, + }, + ) +} + +fn read_char_length_units(reader: &mut R) -> Result { + match read_u8(reader)? { + 0 => Ok(CharLengthUnits::Characters), + 1 => Ok(CharLengthUnits::Octets), + tag => Err(DatabaseError::InvalidValue(format!( + "invalid char length units tag: {tag}" + ))), + } +} #[cfg(all(test, not(target_arch = "wasm32")))] pub(crate) mod test { use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::rocksdb::RocksTransaction; - use crate::types::value::DataValue; + use crate::types::value::{DataValue, Utf8Type}; + use crate::types::CharLengthUnits; + use ordered_float::OrderedFloat; + #[cfg(feature = "decimal")] + use rust_decimal::Decimal; use std::io::{Cursor, Seek, SeekFrom}; #[test] fn test_serialization() -> Result<(), DatabaseError> { - let source_0 = DataValue::Null; - let source_1 = DataValue::Int32(32); - let source_2 = DataValue::Null; - let source_3 = DataValue::Null; - let source_4 = DataValue::Null; - let source_5 = DataValue::Tuple(vec![DataValue::Null, DataValue::Int32(42)], false); + let sources = vec![ + DataValue::Null, + DataValue::Boolean(true), + DataValue::Float32(OrderedFloat(32.5)), + DataValue::Float64(OrderedFloat(64.5)), + DataValue::Int8(-8), + DataValue::Int16(-16), + DataValue::Int32(32), + DataValue::Int64(-64), + DataValue::UInt8(8), + DataValue::UInt16(16), + DataValue::UInt32(32), + DataValue::UInt64(64), + DataValue::Utf8 { + value: "hello".to_string(), + ty: Utf8Type::Variable(Some(16)), + unit: CharLengthUnits::Characters, + }, + DataValue::Utf8 { + value: "octets".to_string(), + ty: Utf8Type::Fixed(8), + unit: CharLengthUnits::Octets, + }, + DataValue::Date32(12), + DataValue::Date64(34), + DataValue::Time32(56, 3), + DataValue::Time64(78, 6, true), + #[cfg(feature = "decimal")] + DataValue::Decimal(Decimal::new(12345, 2)), + DataValue::Tuple(vec![DataValue::Null, DataValue::Int32(42)], false), + ]; let mut reference_tables = ReferenceTables::new(); let mut bytes = Vec::new(); let mut cursor = Cursor::new(&mut bytes); + let mut arena = crate::planner::TableArena::default(); - source_0.encode(&mut cursor, false, &mut reference_tables)?; - source_1.encode(&mut cursor, false, &mut reference_tables)?; - source_2.encode(&mut cursor, false, &mut reference_tables)?; - source_3.encode(&mut cursor, false, &mut reference_tables)?; - source_4.encode(&mut cursor, false, &mut reference_tables)?; - source_5.encode(&mut cursor, false, &mut reference_tables)?; + for source in &sources { + source.encode(&mut cursor, false, &mut reference_tables, &arena)?; + } cursor.seek(SeekFrom::Start(0))?; - let decoded_0 = - DataValue::decode::(&mut cursor, None, &reference_tables).unwrap(); - let decoded_1 = - DataValue::decode::(&mut cursor, None, &reference_tables).unwrap(); - let decoded_2 = - DataValue::decode::(&mut cursor, None, &reference_tables).unwrap(); - let decoded_3 = - DataValue::decode::(&mut cursor, None, &reference_tables).unwrap(); - let decoded_4 = - DataValue::decode::(&mut cursor, None, &reference_tables).unwrap(); - let decoded_5 = - DataValue::decode::(&mut cursor, None, &reference_tables).unwrap(); - - assert_eq!(source_0, decoded_0); - assert_eq!(source_1, decoded_1); - assert_eq!(source_2, decoded_2); - assert_eq!(source_3, decoded_3); - assert_eq!(source_4, decoded_4); - assert_eq!(source_5, decoded_5); + for source in sources { + let decoded = DataValue::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena, + )?; + assert_eq!(source, decoded); + } Ok(()) } diff --git a/src/serdes/evaluator.rs b/src/serdes/evaluator.rs index 42b84310..4f47af14 100644 --- a/src/serdes/evaluator.rs +++ b/src/serdes/evaluator.rs @@ -13,79 +13,366 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::expression::{BinaryOperator, UnaryOperator}; use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; use crate::storage::Transaction; use crate::types::evaluator::{ - binary_create, cast_create, unary_create, BinaryEvaluatorBox, CastEvaluatorBox, - UnaryEvaluatorBox, + BinaryEvaluatorParams, BinaryEvaluatorRef, CastEvaluatorParams, CastEvaluatorRef, + UnaryEvaluatorRef, }; -use crate::types::LogicalType; -use std::borrow::Cow; use std::io::{Read, Write}; -impl ReferenceSerialization for UnaryEvaluatorBox { - fn encode( +impl ReferenceSerialization for BinaryEvaluatorParams { + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.ty.encode(writer, is_direct, reference_tables)?; - self.op.encode(writer, is_direct, reference_tables) + match self { + BinaryEvaluatorParams::Unit => 0u8.encode(writer, is_direct, reference_tables, arena), + BinaryEvaluatorParams::Like { escape_char } => { + 1u8.encode(writer, is_direct, reference_tables, arena)?; + escape_char.encode(writer, is_direct, reference_tables, arena) + } + } } - fn decode( + fn decode( reader: &mut R, context: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let ty = LogicalType::decode(reader, context, reference_tables)?; - let op = UnaryOperator::decode(reader, context, reference_tables)?; - unary_create(Cow::Owned(ty), op) + Ok( + match u8::decode(reader, context, reference_tables, arena)? { + 0 => BinaryEvaluatorParams::Unit, + 1 => BinaryEvaluatorParams::Like { + escape_char: Option::::decode(reader, context, reference_tables, arena)?, + }, + _ => unreachable!(), + }, + ) } } -impl ReferenceSerialization for BinaryEvaluatorBox { - fn encode( +impl ReferenceSerialization for CastEvaluatorParams { + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.ty.encode(writer, is_direct, reference_tables)?; - self.op.encode(writer, is_direct, reference_tables) + match self { + CastEvaluatorParams::Identity => 0u8.encode(writer, is_direct, reference_tables, arena), + CastEvaluatorParams::Unit => 1u8.encode(writer, is_direct, reference_tables, arena), + CastEvaluatorParams::String { len, unit } => { + 2u8.encode(writer, is_direct, reference_tables, arena)?; + len.encode(writer, is_direct, reference_tables, arena)?; + unit.encode(writer, is_direct, reference_tables, arena) + } + #[cfg(feature = "decimal")] + CastEvaluatorParams::Decimal { precision, scale } => { + 3u8.encode(writer, is_direct, reference_tables, arena)?; + precision.encode(writer, is_direct, reference_tables, arena)?; + scale.encode(writer, is_direct, reference_tables, arena) + } + CastEvaluatorParams::Precision { precision } => { + 4u8.encode(writer, is_direct, reference_tables, arena)?; + precision.encode(writer, is_direct, reference_tables, arena) + } + CastEvaluatorParams::Timestamp { precision, zone } => { + 5u8.encode(writer, is_direct, reference_tables, arena)?; + precision.encode(writer, is_direct, reference_tables, arena)?; + zone.encode(writer, is_direct, reference_tables, arena) + } + CastEvaluatorParams::Tuple { evaluators } => { + 6u8.encode(writer, is_direct, reference_tables, arena)?; + evaluators.encode(writer, is_direct, reference_tables, arena) + } + } } - fn decode( + fn decode( reader: &mut R, context: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let ty = LogicalType::decode(reader, context, reference_tables)?; - let op = BinaryOperator::decode(reader, context, reference_tables)?; - binary_create(Cow::Owned(ty), op) + Ok( + match u8::decode(reader, context, reference_tables, arena)? { + 0 => CastEvaluatorParams::Identity, + 1 => CastEvaluatorParams::Unit, + 2 => CastEvaluatorParams::String { + len: Option::::decode(reader, context, reference_tables, arena)?, + unit: crate::types::CharLengthUnits::decode( + reader, + context, + reference_tables, + arena, + )?, + }, + #[cfg(feature = "decimal")] + 3 => CastEvaluatorParams::Decimal { + precision: Option::::decode(reader, context, reference_tables, arena)?, + scale: Option::::decode(reader, context, reference_tables, arena)?, + }, + #[cfg(not(feature = "decimal"))] + 3 => { + let _ = Option::::decode(reader, context, reference_tables, arena)?; + let _ = Option::::decode(reader, context, reference_tables, arena)?; + return Err(DatabaseError::UnsupportedStmt( + "DECIMAL requires the `decimal` feature".to_string(), + )); + } + 4 => CastEvaluatorParams::Precision { + precision: Option::::decode(reader, context, reference_tables, arena)?, + }, + 5 => CastEvaluatorParams::Timestamp { + precision: Option::::decode(reader, context, reference_tables, arena)?, + zone: bool::decode(reader, context, reference_tables, arena)?, + }, + 6 => CastEvaluatorParams::Tuple { + evaluators: Vec::::decode( + reader, + context, + reference_tables, + arena, + )?, + }, + _ => unreachable!(), + }, + ) } } -impl ReferenceSerialization for CastEvaluatorBox { - fn encode( +impl ReferenceSerialization for UnaryEvaluatorRef { + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.from.encode(writer, is_direct, reference_tables)?; - self.to.encode(writer, is_direct, reference_tables) + self.pos.encode(writer, is_direct, reference_tables, arena) } - fn decode( + fn decode( reader: &mut R, context: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let from = LogicalType::decode(reader, context, reference_tables)?; - let to = LogicalType::decode(reader, context, reference_tables)?; - cast_create(Cow::Owned(from), Cow::Owned(to)) + Ok(UnaryEvaluatorRef::new(u16::decode( + reader, + context, + reference_tables, + arena, + )?)) + } +} + +impl ReferenceSerialization for BinaryEvaluatorRef { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + arena: &A, + ) -> Result<(), DatabaseError> { + self.pos + .encode(writer, is_direct, reference_tables, arena)?; + self.params + .encode(writer, is_direct, reference_tables, arena) + } + + fn decode( + reader: &mut R, + context: Option<&ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + arena: &mut A, + ) -> Result { + Ok(BinaryEvaluatorRef::new( + u16::decode(reader, context, reference_tables, arena)?, + BinaryEvaluatorParams::decode(reader, context, reference_tables, arena)?, + )) + } +} + +impl ReferenceSerialization for CastEvaluatorRef { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + arena: &A, + ) -> Result<(), DatabaseError> { + self.pos + .encode(writer, is_direct, reference_tables, arena)?; + self.params + .encode(writer, is_direct, reference_tables, arena) + } + + fn decode( + reader: &mut R, + context: Option<&ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + arena: &mut A, + ) -> Result { + Ok(CastEvaluatorRef::new( + u16::decode(reader, context, reference_tables, arena)?, + CastEvaluatorParams::decode(reader, context, reference_tables, arena)?, + )) + } +} + +#[cfg(all(test, not(target_arch = "wasm32")))] +mod tests { + use super::*; + use crate::storage::rocksdb::RocksTransaction; + use crate::types::CharLengthUnits; + use std::io::{Cursor, Seek, SeekFrom}; + + fn roundtrip_cast_params( + params: CastEvaluatorParams, + expected_tag: u8, + ) -> Result<(), DatabaseError> { + let mut cursor = Cursor::new(Vec::new()); + let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); + + params.encode(&mut cursor, false, &mut reference_tables, &arena)?; + assert_eq!(cursor.get_ref()[0], expected_tag); + cursor.seek(SeekFrom::Start(0))?; + + assert_eq!( + CastEvaluatorParams::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena, + )?, + params + ); + + Ok(()) + } + + fn roundtrip_cast_ref(evaluator: CastEvaluatorRef) -> Result<(), DatabaseError> { + let mut cursor = Cursor::new(Vec::new()); + let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); + + evaluator.encode(&mut cursor, false, &mut reference_tables, &arena)?; + cursor.seek(SeekFrom::Start(0))?; + + assert_eq!( + CastEvaluatorRef::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena, + )?, + evaluator + ); + + Ok(()) + } + + #[test] + fn cast_evaluator_params_serialization_roundtrips_all_variants() -> Result<(), DatabaseError> { + roundtrip_cast_params(CastEvaluatorParams::Identity, 0)?; + roundtrip_cast_params(CastEvaluatorParams::Unit, 1)?; + roundtrip_cast_params( + CastEvaluatorParams::String { + len: Some(12), + unit: CharLengthUnits::Octets, + }, + 2, + )?; + roundtrip_cast_params( + CastEvaluatorParams::String { + len: None, + unit: CharLengthUnits::Characters, + }, + 2, + )?; + #[cfg(feature = "decimal")] + roundtrip_cast_params( + CastEvaluatorParams::Decimal { + precision: Some(10), + scale: Some(2), + }, + 3, + )?; + #[cfg(feature = "decimal")] + roundtrip_cast_params( + CastEvaluatorParams::Decimal { + precision: None, + scale: None, + }, + 3, + )?; + roundtrip_cast_params(CastEvaluatorParams::Precision { precision: Some(6) }, 4)?; + roundtrip_cast_params(CastEvaluatorParams::Precision { precision: None }, 4)?; + roundtrip_cast_params( + CastEvaluatorParams::Timestamp { + precision: Some(3), + zone: true, + }, + 5, + )?; + roundtrip_cast_params( + CastEvaluatorParams::Timestamp { + precision: None, + zone: false, + }, + 5, + )?; + roundtrip_cast_params( + CastEvaluatorParams::Tuple { + evaluators: vec![ + CastEvaluatorRef::new(1, CastEvaluatorParams::Unit), + CastEvaluatorRef::new( + 2, + CastEvaluatorParams::String { + len: Some(8), + unit: CharLengthUnits::Characters, + }, + ), + ], + }, + 6, + )?; + + Ok(()) + } + + #[test] + fn cast_evaluator_ref_serialization_roundtrips_nested_tuple_params() -> Result<(), DatabaseError> + { + roundtrip_cast_ref(CastEvaluatorRef::new( + 42, + CastEvaluatorParams::Tuple { + evaluators: vec![ + #[cfg(feature = "decimal")] + CastEvaluatorRef::new( + 3, + CastEvaluatorParams::Decimal { + precision: Some(18), + scale: Some(4), + }, + ), + CastEvaluatorRef::new( + 4, + CastEvaluatorParams::Timestamp { + precision: Some(9), + zone: true, + }, + ), + ], + }, + )) } } diff --git a/src/serdes/function.rs b/src/serdes/function.rs index 9e45627a..fa177d12 100644 --- a/src/serdes/function.rs +++ b/src/serdes/function.rs @@ -21,21 +21,24 @@ use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for ArcScalarFunctionImpl { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.summary().encode(writer, is_direct, reference_tables) + self.summary() + .encode(writer, is_direct, reference_tables, arena) } - fn decode( + fn decode( reader: &mut R, context: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let summary = FunctionSummary::decode(reader, context, reference_tables)?; + let summary = FunctionSummary::decode(reader, context, reference_tables, arena)?; let Some(functions) = context.and_then(ReferenceDecodeContext::scala_functions) else { return Err(DatabaseError::InvalidValue(format!( "scalar function decode context missing for {}", @@ -54,21 +57,24 @@ impl ReferenceSerialization for ArcScalarFunctionImpl { } impl ReferenceSerialization for ArcTableFunctionImpl { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.summary().encode(writer, is_direct, reference_tables) + self.summary() + .encode(writer, is_direct, reference_tables, arena) } - fn decode( + fn decode( reader: &mut R, context: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let summary = FunctionSummary::decode(reader, context, reference_tables)?; + let summary = FunctionSummary::decode(reader, context, reference_tables, arena)?; let Some(functions) = context.and_then(ReferenceDecodeContext::table_functions) else { return Err(DatabaseError::InvalidValue(format!( "table function decode context missing for {}", @@ -82,6 +88,6 @@ impl ReferenceSerialization for ArcTableFunctionImpl { ))); }; - Ok(Self(function.clone())) + Ok(function.inner.clone()) } } diff --git a/src/serdes/hasher.rs b/src/serdes/hasher.rs index 0fe0da55..d527dec1 100644 --- a/src/serdes/hasher.rs +++ b/src/serdes/hasher.rs @@ -12,7 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::implement_serialization_by_bincode; +use crate::errors::DatabaseError; use crate::optimizer::core::cm_sketch::FastHasher; +use crate::serdes::{ReferenceSerialization, ReferenceTables}; +use crate::storage::Transaction; -implement_serialization_by_bincode!(FastHasher); +impl ReferenceSerialization for FastHasher { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + arena: &A, + ) -> Result<(), DatabaseError> { + let (key0, key1) = self.keys(); + key0.encode(writer, is_direct, reference_tables, arena)?; + key1.encode(writer, is_direct, reference_tables, arena) + } + + fn decode( + reader: &mut R, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + arena: &mut A, + ) -> Result { + let key0 = u64::decode(reader, drive, reference_tables, arena)?; + let key1 = u64::decode(reader, drive, reference_tables, arena)?; + + Ok(FastHasher::new_with_keys(key0, key1)) + } +} diff --git a/src/serdes/index.rs b/src/serdes/index.rs new file mode 100644 index 00000000..e676da14 --- /dev/null +++ b/src/serdes/index.rs @@ -0,0 +1,44 @@ +// Copyright 2024 KipData/KiteSQL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::errors::DatabaseError; +use crate::planner::MetaArena; +use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; +use crate::storage::Transaction; +use crate::types::index::{IndexMeta, IndexMetaRef}; +use std::io::{Read, Write}; + +impl ReferenceSerialization for IndexMetaRef { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + arena: &A, + ) -> Result<(), DatabaseError> { + arena + .index(*self) + .encode(writer, is_direct, reference_tables, arena) + } + + fn decode( + reader: &mut R, + drive: Option<&ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + arena: &mut A, + ) -> Result { + let index = IndexMeta::decode(reader, drive, reference_tables, arena)?; + Ok(arena.alloc_index(index)) + } +} diff --git a/src/serdes/mod.rs b/src/serdes/mod.rs index c92abb3d..777f2e8c 100644 --- a/src/serdes/mod.rs +++ b/src/serdes/mod.rs @@ -22,6 +22,7 @@ mod data_value; mod evaluator; mod function; mod hasher; +mod index; mod num; mod option; mod pair; @@ -37,48 +38,25 @@ mod vec; use crate::catalog::TableName; use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; +use crate::planner::MetaArena; use crate::storage::{TableCache, Transaction}; use std::io; use std::io::{Read, Write}; -#[macro_export] -macro_rules! implement_serialization_by_bincode { - ($struct_name:ident) => { - impl $crate::serdes::ReferenceSerialization for $struct_name { - fn encode( - &self, - writer: &mut W, - _: bool, - _: &mut $crate::serdes::ReferenceTables, - ) -> Result<(), $crate::errors::DatabaseError> { - bincode::serialize_into(writer, self)?; - - Ok(()) - } - - fn decode( - reader: &mut R, - _: Option<&$crate::serdes::ReferenceDecodeContext<'_, T>>, - _: &$crate::serdes::ReferenceTables, - ) -> Result { - Ok(bincode::deserialize_from(reader)?) - } - } - }; -} - pub trait ReferenceSerialization { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError>; - fn decode( + fn decode( reader: &mut R, context: Option<&ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result where Self: Sized; @@ -124,11 +102,19 @@ impl<'a, T: Transaction> ReferenceDecodeContext<'a, T> { } } -#[derive(Debug, Default, Eq, PartialEq)] +#[derive(Debug, Default)] pub struct ReferenceTables { tables: Vec, } +impl PartialEq for ReferenceTables { + fn eq(&self, other: &Self) -> bool { + self.tables == other.tables + } +} + +impl Eq for ReferenceTables {} + impl ReferenceTables { pub fn new() -> Self { ReferenceTables { tables: vec![] } @@ -138,6 +124,10 @@ impl ReferenceTables { self.tables.is_empty() } + pub fn clear(&mut self) { + self.tables.clear(); + } + pub fn len(&self) -> usize { self.tables.len() } @@ -193,9 +183,9 @@ mod tests { #[test] fn test_to_raw() -> io::Result<()> { - let reference_tables = ReferenceTables { - tables: vec!["t1".to_string().into(), "t2".to_string().into()], - }; + let mut reference_tables = ReferenceTables::new(); + reference_tables.push_or_replace(&"t1".to_string().into()); + reference_tables.push_or_replace(&"t2".to_string().into()); let mut cursor = io::Cursor::new(Vec::new()); reference_tables.to_raw(&mut cursor)?; diff --git a/src/serdes/num.rs b/src/serdes/num.rs index ef06e1b1..264040cb 100644 --- a/src/serdes/num.rs +++ b/src/serdes/num.rs @@ -23,21 +23,23 @@ use std::mem::size_of; macro_rules! implement_num_serialization { ($struct_name:ident) => { impl ReferenceSerialization for $struct_name { - fn encode( + fn encode( &self, writer: &mut W, _: bool, _: &mut ReferenceTables, + _: &A, ) -> Result<(), DatabaseError> { writer.write_all(&self.to_le_bytes()[..])?; Ok(()) } - fn decode( + fn decode( reader: &mut R, - _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + _: Option<&$crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, + _: &mut A, ) -> Result { let mut bytes = [0u8; size_of::()]; reader.read_exact(&mut bytes)?; @@ -61,21 +63,23 @@ implement_num_serialization!(f32); implement_num_serialization!(f64); impl ReferenceSerialization for usize { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - (*self as u32).encode(writer, is_direct, reference_tables) + (*self as u32).encode(writer, is_direct, reference_tables, arena) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - Ok(u32::decode(reader, drive, reference_tables)? as usize) + Ok(u32::decode(reader, drive, reference_tables, arena)? as usize) } } @@ -103,43 +107,59 @@ pub(crate) mod test { let mut reference_tables = ReferenceTables::new(); let mut bytes = Vec::new(); let mut cursor = Cursor::new(&mut bytes); - - source_0.encode(&mut cursor, false, &mut reference_tables)?; - source_1.encode(&mut cursor, false, &mut reference_tables)?; - source_2.encode(&mut cursor, false, &mut reference_tables)?; - source_3.encode(&mut cursor, false, &mut reference_tables)?; - source_4.encode(&mut cursor, false, &mut reference_tables)?; - source_5.encode(&mut cursor, false, &mut reference_tables)?; - source_6.encode(&mut cursor, false, &mut reference_tables)?; - source_7.encode(&mut cursor, false, &mut reference_tables)?; - source_8.encode(&mut cursor, false, &mut reference_tables)?; - source_9.encode(&mut cursor, false, &mut reference_tables)?; - source_10.encode(&mut cursor, false, &mut reference_tables)?; + let mut arena = crate::planner::TableArena::default(); + + source_0.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_1.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_2.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_3.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_4.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_5.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_6.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_7.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_8.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_9.encode(&mut cursor, false, &mut reference_tables, &arena)?; + source_10.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; let decoded_0 = - u8::decode::(&mut cursor, None, &reference_tables).unwrap(); + u8::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_1 = - u16::decode::(&mut cursor, None, &reference_tables).unwrap(); + u16::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_2 = - u32::decode::(&mut cursor, None, &reference_tables).unwrap(); + u32::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_3 = - u64::decode::(&mut cursor, None, &reference_tables).unwrap(); + u64::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_4 = - i8::decode::(&mut cursor, None, &reference_tables).unwrap(); + i8::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_5 = - i16::decode::(&mut cursor, None, &reference_tables).unwrap(); + i16::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_6 = - i32::decode::(&mut cursor, None, &reference_tables).unwrap(); + i32::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_7 = - i64::decode::(&mut cursor, None, &reference_tables).unwrap(); + i64::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_8 = - f32::decode::(&mut cursor, None, &reference_tables).unwrap(); + f32::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); let decoded_9 = - f64::decode::(&mut cursor, None, &reference_tables).unwrap(); - let decoded_10 = - usize::decode::(&mut cursor, None, &reference_tables).unwrap(); + f64::decode::(&mut cursor, None, &reference_tables, &mut arena) + .unwrap(); + let decoded_10 = usize::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena, + ) + .unwrap(); assert_eq!(source_0, decoded_0); assert_eq!(source_1, decoded_1); diff --git a/src/serdes/option.rs b/src/serdes/option.rs index 639d6b21..60b57991 100644 --- a/src/serdes/option.rs +++ b/src/serdes/option.rs @@ -21,31 +21,33 @@ impl ReferenceSerialization for Option where V: ReferenceSerialization, { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { match self { - None => 0u8.encode(writer, is_direct, reference_tables)?, + None => 0u8.encode(writer, is_direct, reference_tables, arena)?, Some(v) => { - 1u8.encode(writer, is_direct, reference_tables)?; - v.encode(writer, is_direct, reference_tables)?; + 1u8.encode(writer, is_direct, reference_tables, arena)?; + v.encode(writer, is_direct, reference_tables, arena)?; } } Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - match u8::decode(reader, drive, reference_tables)? { + match u8::decode(reader, drive, reference_tables, arena)? { 0 => Ok(None), - 1 => Ok(Some(V::decode(reader, drive, reference_tables)?)), + 1 => Ok(Some(V::decode(reader, drive, reference_tables, arena)?)), _ => unreachable!(), } } diff --git a/src/serdes/pair.rs b/src/serdes/pair.rs index 58c19ac3..f9c9b904 100644 --- a/src/serdes/pair.rs +++ b/src/serdes/pair.rs @@ -22,26 +22,28 @@ where A: ReferenceSerialization, B: ReferenceSerialization, { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &AR, ) -> Result<(), DatabaseError> { let (v1, v2) = self; - v1.encode(writer, is_direct, reference_tables)?; - v2.encode(writer, is_direct, reference_tables)?; + v1.encode(writer, is_direct, reference_tables, arena)?; + v2.encode(writer, is_direct, reference_tables, arena)?; Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut AR, ) -> Result { - let v1 = A::decode(reader, drive, reference_tables)?; - let v2 = B::decode(reader, drive, reference_tables)?; + let v1 = A::decode(reader, drive, reference_tables, arena)?; + let v2 = B::decode(reader, drive, reference_tables, arena)?; Ok((v1, v2)) } diff --git a/src/serdes/path_buf.rs b/src/serdes/path_buf.rs index aba7b38e..9ecb6117 100644 --- a/src/serdes/path_buf.rs +++ b/src/serdes/path_buf.rs @@ -12,7 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::implement_serialization_by_bincode; +use crate::errors::DatabaseError; +use crate::serdes::{ReferenceSerialization, ReferenceTables}; +use crate::storage::Transaction; use std::path::PathBuf; +use std::str::FromStr; -implement_serialization_by_bincode!(PathBuf); +impl ReferenceSerialization for PathBuf { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + arena: &A, + ) -> Result<(), DatabaseError> { + self.to_str() + .ok_or_else(|| DatabaseError::InvalidValue("path is not valid utf8".to_string()))? + .to_string() + .encode(writer, is_direct, reference_tables, arena) + } + + fn decode( + reader: &mut R, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + arena: &mut A, + ) -> Result { + Ok(PathBuf::from_str(&String::decode( + reader, + drive, + reference_tables, + arena, + )?)?) + } +} diff --git a/src/serdes/phantom.rs b/src/serdes/phantom.rs index 684097da..40dad173 100644 --- a/src/serdes/phantom.rs +++ b/src/serdes/phantom.rs @@ -19,19 +19,21 @@ use std::io::{Read, Write}; use std::marker::PhantomData; impl ReferenceSerialization for PhantomData { - fn encode( + fn encode( &self, _: &mut W, _: bool, _: &mut ReferenceTables, + _: &A, ) -> Result<(), DatabaseError> { Ok(()) } - fn decode( + fn decode( _: &mut R, _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, + _: &mut A, ) -> Result { Ok(PhantomData) } diff --git a/src/serdes/ptr.rs b/src/serdes/ptr.rs index 9cba60ac..7f0a5c69 100644 --- a/src/serdes/ptr.rs +++ b/src/serdes/ptr.rs @@ -25,19 +25,22 @@ macro_rules! implement_ptr_serialization { where V: ReferenceSerialization, { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.as_ref().encode(writer, is_direct, reference_tables) + self.as_ref() + .encode(writer, is_direct, reference_tables, arena) } - fn decode( + fn decode( reader: &mut R, - drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + drive: Option<&$crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result where Self: Sized, @@ -46,6 +49,7 @@ macro_rules! implement_ptr_serialization { reader, drive, reference_tables, + arena, )?)) } } diff --git a/src/serdes/slice.rs b/src/serdes/slice.rs index 523ffa9c..97dcd019 100644 --- a/src/serdes/slice.rs +++ b/src/serdes/slice.rs @@ -21,26 +21,28 @@ impl ReferenceSerialization for [V; 2] where V: ReferenceSerialization, { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self[0].encode(writer, is_direct, reference_tables)?; - self[1].encode(writer, is_direct, reference_tables)?; + self[0].encode(writer, is_direct, reference_tables, arena)?; + self[1].encode(writer, is_direct, reference_tables, arena)?; Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { Ok([ - V::decode(reader, drive, reference_tables)?, - V::decode(reader, drive, reference_tables)?, + V::decode(reader, drive, reference_tables, arena)?, + V::decode(reader, drive, reference_tables, arena)?, ]) } } diff --git a/src/serdes/string.rs b/src/serdes/string.rs index 720b9754..e5539d3d 100644 --- a/src/serdes/string.rs +++ b/src/serdes/string.rs @@ -13,31 +13,62 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::implement_serialization_by_bincode; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::Transaction; +use std::io::{Read, Write}; use std::sync::Arc; -implement_serialization_by_bincode!(String); +impl ReferenceSerialization for String { + fn encode( + &self, + writer: &mut W, + is_direct: bool, + reference_tables: &mut ReferenceTables, + arena: &A, + ) -> Result<(), DatabaseError> { + self.len() + .encode(writer, is_direct, reference_tables, arena)?; + writer.write_all(self.as_bytes())?; + + Ok(()) + } + + fn decode( + reader: &mut R, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + arena: &mut A, + ) -> Result { + let len = usize::decode(reader, drive, reference_tables, arena)?; + let mut bytes = vec![0; len]; + reader.read_exact(&mut bytes)?; + + Ok(String::from_utf8(bytes)?) + } +} impl ReferenceSerialization for Arc { - fn encode( + fn encode( &self, writer: &mut W, - _: bool, - _: &mut ReferenceTables, + is_direct: bool, + reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - bincode::serialize_into(writer, self)?; + self.len() + .encode(writer, is_direct, reference_tables, arena)?; + writer.write_all(self.as_bytes())?; Ok(()) } - fn decode( + fn decode( reader: &mut R, - _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, - _: &ReferenceTables, + drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, + reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let str: String = bincode::deserialize_from(reader)?; + let str = String::decode(reader, drive, reference_tables, arena)?; Ok(str.into()) } } @@ -54,12 +85,18 @@ pub(crate) mod test { let mut bytes = Vec::new(); let mut cursor = Cursor::new(&mut bytes); let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); let source = "hello".to_string(); - ReferenceSerialization::encode(&source, &mut cursor, true, &mut reference_tables)?; + ReferenceSerialization::encode(&source, &mut cursor, true, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( - String::decode::(&mut cursor, None, &reference_tables)?, + String::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena, + )?, source ); diff --git a/src/serdes/trim.rs b/src/serdes/trim.rs index d9aa39ca..42ca63fc 100644 --- a/src/serdes/trim.rs +++ b/src/serdes/trim.rs @@ -19,11 +19,12 @@ use crate::storage::Transaction; use std::io::{Read, Write}; impl ReferenceSerialization for TrimWhereField { - fn encode( + fn encode( &self, writer: &mut W, _: bool, _: &mut ReferenceTables, + _: &A, ) -> Result<(), DatabaseError> { let type_id = match self { TrimWhereField::Both => 0, @@ -35,10 +36,11 @@ impl ReferenceSerialization for TrimWhereField { Ok(()) } - fn decode( + fn decode( reader: &mut R, _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, + _: &mut A, ) -> Result { let mut one_byte = [0u8; 1]; reader.read_exact(&mut one_byte)?; diff --git a/src/serdes/ulid.rs b/src/serdes/ulid.rs index 52804d5b..e0cc7da9 100644 --- a/src/serdes/ulid.rs +++ b/src/serdes/ulid.rs @@ -19,21 +19,23 @@ use std::io::{Read, Write}; use ulid::Ulid; impl ReferenceSerialization for Ulid { - fn encode( + fn encode( &self, writer: &mut W, _: bool, _: &mut ReferenceTables, + _: &A, ) -> Result<(), DatabaseError> { writer.write_all(&self.to_bytes()[..])?; Ok(()) } - fn decode( + fn decode( reader: &mut R, _: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, _: &ReferenceTables, + _: &mut A, ) -> Result { let mut buf = [0u8; 16]; reader.read_exact(&mut buf)?; diff --git a/src/serdes/vec.rs b/src/serdes/vec.rs index 97f69aff..6e3a318b 100644 --- a/src/serdes/vec.rs +++ b/src/serdes/vec.rs @@ -21,28 +21,32 @@ impl ReferenceSerialization for Vec where V: ReferenceSerialization, { - fn encode( + fn encode( &self, writer: &mut W, is_direct: bool, reference_tables: &mut ReferenceTables, + arena: &A, ) -> Result<(), DatabaseError> { - self.len().encode(writer, is_direct, reference_tables)?; + self.len() + .encode(writer, is_direct, reference_tables, arena)?; for value in self.iter() { - value.encode(writer, is_direct, reference_tables)? + value.encode(writer, is_direct, reference_tables, arena)? } Ok(()) } - fn decode( + fn decode( reader: &mut R, drive: Option<&crate::serdes::ReferenceDecodeContext<'_, T>>, reference_tables: &ReferenceTables, + arena: &mut A, ) -> Result { - let len = ::decode(reader, drive, reference_tables)?; + let len = + ::decode(reader, drive, reference_tables, arena)?; let mut vec = Vec::with_capacity(len); for _ in 0..len { - vec.push(V::decode(reader, drive, reference_tables)?); + vec.push(V::decode(reader, drive, reference_tables, arena)?); } Ok(vec) } diff --git a/src/storage/lmdb.rs b/src/storage/lmdb.rs index 6c8685e6..d3f0c06d 100644 --- a/src/storage/lmdb.rs +++ b/src/storage/lmdb.rs @@ -13,10 +13,10 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::storage::table_codec::{Bytes, TableCodec}; +use crate::storage::table_codec::Bytes; use crate::storage::{ - reuse_bound_as_excluded, InnerIter, KeyValueRef, Storage, Transaction, - TransactionIsolationLevel, + bytes_bound_as_slice, owned_bound, reuse_bound_as_excluded, InnerIter, KeyValueRef, Storage, + Transaction, TransactionIsolationLevel, }; use lmdb::{ Cursor, Database, DatabaseFlags, Environment, EnvironmentFlags, RoCursor, RwTransaction, @@ -107,10 +107,8 @@ impl LmdbStorage { if let Some(max_dbs) = config.max_dbs { builder.set_max_dbs(max_dbs); } - let env = builder.open(&path).map_err(map_lmdb_err)?; - let db = env - .create_db(None, DatabaseFlags::empty()) - .map_err(map_lmdb_err)?; + let env = builder.open(&path)?; + let db = env.create_db(None, DatabaseFlags::empty())?; Ok(Self { env: Arc::new(env), @@ -133,13 +131,9 @@ impl Storage for LmdbStorage { isolation: TransactionIsolationLevel, ) -> Result, DatabaseError> { self.validate_transaction_isolation(isolation)?; - let tx = self.env.begin_rw_txn().map_err(map_lmdb_err)?; + let tx = self.env.begin_rw_txn()?; - Ok(LmdbTransaction { - tx, - db: self.db, - table_codec: Default::default(), - }) + Ok(LmdbTransaction { tx, db: self.db }) } fn default_transaction_isolation(&self) -> TransactionIsolationLevel { @@ -167,7 +161,6 @@ impl Storage for LmdbStorage { pub struct LmdbTransaction<'env> { tx: RwTransaction<'env>, db: Database, - table_codec: TableCodec, } pub struct LmdbIter<'txn> { @@ -181,7 +174,7 @@ pub struct LmdbIter<'txn> { impl LmdbIter<'_> { fn next_visible(&mut self) -> Option<(&[u8], &[u8])> { if let Some(entry) = self.pending.take() { - if within_upper_bound(entry.0, &self.max) { + if within_upper_bound(entry.0, bytes_bound_as_slice(&self.max)) { return Some(entry); } self.done = true; @@ -189,7 +182,7 @@ impl LmdbIter<'_> { } if let Some((key, value)) = self.iter.next() { - if !within_upper_bound(key, &self.max) { + if !within_upper_bound(key, bytes_bound_as_slice(&self.max)) { self.done = true; return None; } @@ -220,11 +213,6 @@ impl Transaction for LmdbTransaction<'_> { = LmdbIter<'a> where Self: 'a; - - fn table_codec(&self) -> *const TableCodec { - &self.table_codec - } - fn get_borrowed<'a>( &'a self, key: &[u8], @@ -232,21 +220,20 @@ impl Transaction for LmdbTransaction<'_> { match self.tx.get(self.db, &key) { Ok(value) => Ok(Some(value)), Err(lmdb::Error::NotFound) => Ok(None), - Err(err) => Err(map_lmdb_err(err)), + Err(err) => Err(err.into()), } } fn set(&mut self, key: &[u8], value: &[u8]) -> Result<(), DatabaseError> { self.tx - .put(self.db, &key, &value, lmdb::WriteFlags::empty()) - .map_err(map_lmdb_err)?; + .put(self.db, &key, &value, lmdb::WriteFlags::empty())?; Ok(()) } fn remove(&mut self, key: &[u8]) -> Result<(), DatabaseError> { match self.tx.del(self.db, &key, None) { Ok(()) | Err(lmdb::Error::NotFound) => Ok(()), - Err(err) => Err(map_lmdb_err(err)), + Err(err) => Err(err.into()), } } @@ -255,8 +242,8 @@ impl Transaction for LmdbTransaction<'_> { min: Bound<&'key [u8]>, max: Bound<&'key [u8]>, ) -> Result, DatabaseError> { - let mut cursor = self.tx.open_ro_cursor(self.db).map_err(map_lmdb_err)?; - let (pending, done) = initial_entry(&mut cursor, &min).map_err(map_lmdb_err)?; + let mut cursor = self.tx.open_ro_cursor(self.db)?; + let (pending, done) = initial_entry(&mut cursor, &min)?; let iter = cursor.iter(); Ok(LmdbIter { @@ -269,27 +256,26 @@ impl Transaction for LmdbTransaction<'_> { } fn remove_range(&mut self, min: Bound<&[u8]>, max: Bound<&[u8]>) -> Result<(), DatabaseError> { - let mut cursor = self.tx.open_rw_cursor(self.db).map_err(map_lmdb_err)?; - let upper = owned_bound(max); + let mut cursor = self.tx.open_rw_cursor(self.db)?; let mut lower = owned_bound(min); let mut seek_key = Bytes::new(); loop { - let entry = cursor_seek(&mut cursor, &lower, &mut seek_key).map_err(map_lmdb_err)?; + let entry = cursor_seek(&mut cursor, &lower, &mut seek_key)?; let Some((key, _)) = entry else { return Ok(()); }; - if !within_upper_bound(key, &upper) { + if !within_upper_bound(key, max) { return Ok(()); } reuse_bound_as_excluded(&mut lower, key); - cursor.del(WriteFlags::empty()).map_err(map_lmdb_err)?; + cursor.del(WriteFlags::empty())?; } } fn commit(self) -> Result<(), DatabaseError> { - self.tx.commit().map_err(map_lmdb_err)?; + self.tx.commit()?; Ok(()) } } @@ -351,30 +337,18 @@ fn cursor_seek<'txn>( } } -fn owned_bound(bound: Bound<&[u8]>) -> Bound { - match bound { - Bound::Included(bytes) => Bound::Included(bytes.to_vec()), - Bound::Excluded(bytes) => Bound::Excluded(bytes.to_vec()), - Bound::Unbounded => Bound::Unbounded, - } -} - -fn within_upper_bound(key: &[u8], max: &Bound) -> bool { +fn within_upper_bound(key: &[u8], max: Bound<&[u8]>) -> bool { match max { - Bound::Included(max) => key.cmp(max.as_slice()) != Ordering::Greater, - Bound::Excluded(max) => key.cmp(max.as_slice()) == Ordering::Less, + Bound::Included(max) => key.cmp(max) != Ordering::Greater, + Bound::Excluded(max) => key.cmp(max) == Ordering::Less, Bound::Unbounded => true, } } -fn map_lmdb_err(err: impl std::fmt::Display) -> DatabaseError { - DatabaseError::InvalidValue(format!("lmdb: {err}")) -} - #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { use super::{LmdbConfig, LmdbStorage}; - use crate::db::DataBaseBuilder; + use crate::db::{CatalogKind, DataBaseBuilder}; use lmdb::EnvironmentFlags; use tempfile::TempDir; @@ -382,12 +356,13 @@ mod tests { fn lmdb_backend_smoke() { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let db_path = temp_dir.path().join("kite_sql.lmdb"); - let kite_sql = DataBaseBuilder::path(db_path).build_lmdb().unwrap(); + let mut kite_sql = DataBaseBuilder::path(db_path).build_lmdb().unwrap(); kite_sql - .run("create table t1 (a int primary key, b int)") - .unwrap() - .done() + .ddl("create table t1 (a int primary key, b int)") + .unwrap(); + kite_sql + .load(CatalogKind::Table("t1".to_string().into())) .unwrap(); kite_sql .run("insert into t1 values (1, 10), (2, 20), (3, 30)") @@ -406,22 +381,18 @@ mod tests { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let db_path = temp_dir.path().join("kite_sql.lmdb"); let storage = LmdbStorage::new(db_path).unwrap(); - let kite_sql = DataBaseBuilder::path(temp_dir.path()) + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()) .build_with_storage(storage) .unwrap(); - kite_sql - .run("create table t1 (a int primary key)") - .unwrap() - .done() - .unwrap(); + kite_sql.ddl("create table t1 (a int primary key)").unwrap(); } #[test] fn collect_lmdb_metrics_snapshot() { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let db_path = temp_dir.path().join("kite_sql.lmdb"); - let kite_sql = DataBaseBuilder::path(db_path) + let mut kite_sql = DataBaseBuilder::path(db_path) .storage_statistics(true) .lmdb_flags(EnvironmentFlags::NO_SYNC) .lmdb_map_size(64 * 1024 * 1024) @@ -429,9 +400,10 @@ mod tests { .unwrap(); kite_sql - .run("create table t_metrics (a int primary key, b int)") - .unwrap() - .done() + .ddl("create table t_metrics (a int primary key, b int)") + .unwrap(); + kite_sql + .load(CatalogKind::Table("t_metrics".to_string().into())) .unwrap(); kite_sql .run("insert into t_metrics values (1, 10), (2, 20), (3, 30)") @@ -457,14 +429,10 @@ mod tests { }, ) .unwrap(); - let kite_sql = DataBaseBuilder::path(temp_dir.path()) + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()) .build_with_storage(storage) .unwrap(); - kite_sql - .run("create table t1 (a int primary key)") - .unwrap() - .done() - .unwrap(); + kite_sql.ddl("create table t1 (a int primary key)").unwrap(); } } diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 6e37a02c..61ab6434 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::storage::table_codec::{Bytes, TableCodec}; +use crate::storage::table_codec::Bytes; use crate::storage::{ EmptyStorageMetrics, InnerIter, Storage, Transaction, TransactionIsolationLevel, }; @@ -47,7 +47,6 @@ impl Storage for MemoryStorage { self.validate_transaction_isolation(isolation)?; Ok(MemoryTransaction { inner: self.inner.clone(), - table_codec: Default::default(), }) } @@ -58,7 +57,6 @@ impl Storage for MemoryStorage { pub struct MemoryTransaction { inner: Rc, Vec>>>, - table_codec: TableCodec, } pub struct MemoryIter { @@ -97,11 +95,6 @@ impl Transaction for MemoryTransaction { = MemoryIter where Self: 'a; - - fn table_codec(&self) -> *const TableCodec { - &self.table_codec - } - fn get_borrowed<'a>( &'a self, key: &[u8], @@ -160,49 +153,52 @@ impl Transaction for MemoryTransaction { #[cfg(all(test, target_arch = "wasm32"))] mod wasm_tests { use super::*; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, TableName}; - use crate::db::DataBaseBuilder; + use crate::catalog::{ColumnCatalog, ColumnDesc, TableName}; + use crate::db::{CatalogKind, DataBaseBuilder}; use crate::expression::range_detacher::Range; + use crate::planner::{PlanArena, TableArenaCell}; + use crate::storage::table_codec::TableCodec; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; - use itertools::Itertools; use std::collections::Bound; - use std::hash::RandomState; - use std::sync::Arc; use wasm_bindgen_test::*; #[wasm_bindgen_test] fn memory_storage_roundtrip() -> Result<(), DatabaseError> { let storage = MemoryStorage::new(); let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let columns = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new( + let mut table_cache = crate::storage::TableCache::default(); + let mut table_codec = TableCodec::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); + let source_columns = vec![ + ColumnCatalog::new( "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::Boolean, None, false, None).unwrap(), - )), - ]); - - let source_columns = columns - .iter() - .map(|col_ref| ColumnCatalog::clone(col_ref)) - .collect_vec(); - transaction.create_table( - &table_cache, + ), + ]; + if let Some(table) = transaction.create_table( + &mut table_codec, + &mut plan_arena, "test".to_string().into(), source_columns, false, - )?; + )? { + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + } + let plan_arena = PlanArena::new(&table_arena); + let table_name: TableName = "test".to_string().into(); transaction.append_tuple( + &mut table_codec, "test", Tuple::new( Some(DataValue::Int32(1)), @@ -215,6 +211,7 @@ mod wasm_tests { false, )?; transaction.append_tuple( + &mut table_codec, "test", Tuple::new( Some(DataValue::Int32(2)), @@ -227,9 +224,18 @@ mod wasm_tests { false, )?; - let read_columns = vec![columns[0].clone()]; + let read_column = table_cache + .get(&table_name) + .unwrap() + .columns() + .next() + .copied() + .unwrap(); + let read_columns = vec![read_column]; let mut iter = transaction.read( + &mut table_codec, + &plan_arena, &table_cache, "test".to_string().into(), (Some(1), Some(1)), @@ -248,10 +254,9 @@ mod wasm_tests { #[wasm_bindgen_test] fn memory_storage_read_by_index() -> Result<(), DatabaseError> { - let kite_sql = DataBaseBuilder::path("./memory").build_in_memory()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; + let mut kite_sql = DataBaseBuilder::path("./memory").build_in_memory()?; + kite_sql.ddl("create table t1 (a int primary key, b int)")?; + kite_sql.load(CatalogKind::Table("t1".to_string().into()))?; kite_sql .run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2), (3, 4)")? .done()?; @@ -263,8 +268,10 @@ mod wasm_tests { .unwrap() .clone(); let pk_index = table.indexes().next().unwrap().clone(); + let plan_arena = PlanArena::new(kite_sql.state.table_arena()); let mut iter = transaction.read_by_index( kite_sql.state.table_cache(), + &plan_arena, table_name, (Some(0), None), table.columns().cloned().collect(), @@ -292,48 +299,51 @@ mod wasm_tests { #[cfg(all(test, not(target_arch = "wasm32")))] mod native_tests { use super::*; - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, TableName}; - use crate::db::DataBaseBuilder; + use crate::catalog::{ColumnCatalog, ColumnDesc, TableName}; + use crate::db::{CatalogKind, DataBaseBuilder}; use crate::expression::range_detacher::Range; + use crate::planner::{PlanArena, TableArenaCell}; + use crate::storage::table_codec::TableCodec; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; - use itertools::Itertools; use std::collections::Bound; - use std::hash::RandomState; - use std::sync::Arc; #[test] fn memory_storage_roundtrip() -> Result<(), DatabaseError> { let storage = MemoryStorage::new(); let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let columns = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new( + let mut table_cache = crate::storage::TableCache::default(); + let mut table_codec = TableCodec::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); + let source_columns = vec![ + ColumnCatalog::new( "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::Boolean, None, false, None).unwrap(), - )), - ]); - - let source_columns = columns - .iter() - .map(|col_ref| ColumnCatalog::clone(col_ref)) - .collect_vec(); - transaction.create_table( - &table_cache, + ), + ]; + if let Some(table) = transaction.create_table( + &mut table_codec, + &mut plan_arena, "test".to_string().into(), source_columns, false, - )?; + )? { + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + } + let plan_arena = PlanArena::new(&table_arena); + let table_name: TableName = "test".to_string().into(); transaction.append_tuple( + &mut table_codec, "test", Tuple::new( Some(DataValue::Int32(1)), @@ -346,6 +356,7 @@ mod native_tests { false, )?; transaction.append_tuple( + &mut table_codec, "test", Tuple::new( Some(DataValue::Int32(2)), @@ -358,9 +369,18 @@ mod native_tests { false, )?; - let read_columns = vec![columns[0].clone()]; + let read_column = table_cache + .get(&table_name) + .unwrap() + .columns() + .next() + .copied() + .unwrap(); + let read_columns = vec![read_column]; let mut iter = transaction.read( + &mut table_codec, + &plan_arena, &table_cache, "test".to_string().into(), (Some(1), Some(1)), @@ -379,10 +399,9 @@ mod native_tests { #[test] fn memory_storage_read_by_index() -> Result<(), DatabaseError> { - let kite_sql = DataBaseBuilder::path("./memory").build_in_memory()?; - kite_sql - .run("create table t1 (a int primary key, b int)")? - .done()?; + let mut kite_sql = DataBaseBuilder::path("./memory").build_in_memory()?; + kite_sql.ddl("create table t1 (a int primary key, b int)")?; + kite_sql.load(CatalogKind::Table("t1".to_string().into()))?; kite_sql .run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2), (3, 4)")? .done()?; @@ -393,9 +412,11 @@ mod native_tests { .table(kite_sql.state.table_cache(), table_name.clone())? .unwrap() .clone(); - let pk_index = table.indexes().next().unwrap().clone(); + let pk_index = *table.indexes().next().unwrap(); + let plan_arena = PlanArena::new(kite_sql.state.table_arena()); let mut iter = transaction.read_by_index( kite_sql.state.table_cache(), + &plan_arena, table_name, (Some(0), None), table.columns().cloned().collect(), diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 3a6d0cd8..40740d17 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -28,34 +28,32 @@ use crate::expression::ScalarExpression; use crate::optimizer::core::cm_sketch::{ CountMinSketch, CountMinSketchPage, COUNT_MIN_SKETCH_STORAGE_PAGE_LEN, }; -use crate::optimizer::core::statistics_meta::{StatisticMetaLoader, StatisticsMeta}; +use crate::optimizer::core::statistics_meta::StatisticsMeta; use crate::planner::operator::alter_table::change_column::{DefaultChange, NotNullChange}; +use crate::planner::{MetaArena, PlanArena, TableArenaCell}; use crate::serdes::ReferenceTables; -use crate::storage::table_codec::{ - BumpBytes, Bytes, StatisticsCodecType, TableCodec, BOUND_MAX_TAG, -}; +use crate::storage::table_codec::{Bytes, StatisticsCodecType, TableCodec, BOUND_MAX_TAG}; use crate::types::index::{Index, IndexId, IndexMeta, IndexMetaRef, IndexType}; use crate::types::serialize::TupleValueSerializableImpl; use crate::types::tuple::{Tuple, TupleId}; use crate::types::value::{DataValue, TupleMappingRef}; use crate::types::{ColumnId, LogicalType}; -use crate::utils::lru::SharedLruCache; +use ahash::HashMap; use itertools::Itertools; -use std::borrow::Cow; +use std::borrow::{Borrow, Cow}; use std::collections::Bound; use std::fmt::{self, Display, Formatter}; use std::io::Cursor; use std::mem; use std::ops::SubAssign; use std::path::Path; -use std::sync::Arc; pub type KeyValueRef<'a> = (&'a [u8], &'a [u8]); use ulid::Generator; -pub(crate) type StatisticsMetaCache = SharedLruCache<(TableName, IndexId), Option>; -pub(crate) type TableCache = SharedLruCache; -pub(crate) type ViewCache = SharedLruCache; +pub(crate) type StatisticsMetaCache = HashMap<(TableName, IndexId), StatisticsMeta>; +pub(crate) type TableCache = HashMap; +pub(crate) type ViewCache = HashMap; /// Transaction isolation levels supported by KiteSQL. /// @@ -80,12 +78,14 @@ impl Display for TransactionIsolationLevel { pub(crate) fn index_value_type( table: &TableCatalog, + arena: &impl MetaArena, column_ids: &[ColumnId], ) -> Result { let mut value_types = Vec::with_capacity(column_ids.len()); for column_id in column_ids { let value_type = table .get_column_by_id(column_id) + .map(|column| arena.column(column)) .ok_or_else(|| DatabaseError::column_not_found(column_id.to_string()))? .datatype() .clone(); @@ -164,8 +164,6 @@ pub trait Transaction: Sized { where Self: 'a; - fn table_codec(&self) -> *const TableCodec; - fn begin_statement_scope(&mut self) -> Result<(), DatabaseError> { Ok(()) } @@ -177,9 +175,12 @@ pub trait Transaction: Sized { /// The bounds is applied to the whole data batches, not per batch. /// /// The projections is column indices. + #[allow(clippy::too_many_arguments)] fn read<'a>( &'a self, - table_cache: &'a TableCache, + table_codec: &mut TableCodec, + arena: &PlanArena, + table_cache: &TableCache, table_name: TableName, bounds: Bounds, columns: Vec, @@ -188,11 +189,11 @@ pub trait Transaction: Sized { let table = self .table(table_cache, table_name.clone())? .ok_or(DatabaseError::TableNotFound)?; - let deserializers = Self::create_deserializers(&columns, table, with_pk); + let deserializers = Self::create_deserializers(&columns, table, arena, with_pk); let pk_ty = with_pk.then(|| table.primary_keys_type().clone()); let offset = bounds.0.unwrap_or(0); - unsafe { &*self.table_codec() }.with_tuple_bound(&table_name, |min, max| { + table_codec.with_tuple_bound(&table_name, |min, max| { let iter = self.range(Bound::Included(min), Bound::Included(max))?; Ok(TupleIter { @@ -208,7 +209,8 @@ pub trait Transaction: Sized { #[allow(clippy::too_many_arguments)] fn read_by_index<'a, R>( &'a self, - table_cache: &'a TableCache, + table_cache: &TableCache, + arena: &PlanArena<'a>, table_name: TableName, (offset_option, limit_option): Bounds, columns: Vec, @@ -221,10 +223,11 @@ pub trait Transaction: Sized { where R: Into, { + let index_meta_ref = index_meta; + let index_meta = arena.index(index_meta_ref); let table = self .table(table_cache, table_name.clone())? .ok_or(DatabaseError::TableNotFound)?; - let table_name = table.name.as_ref(); let offset = offset_option.unwrap_or(0); let is_primary_index = matches!(index_meta.ty, IndexType::PrimaryKey { .. }); @@ -247,18 +250,20 @@ pub trait Transaction: Sized { ) } _ => { - let deserializers = Self::create_deserializers(&columns, table, with_pk); + let deserializers = Self::create_deserializers(&columns, table, arena, with_pk); (IndexImplEnum::instance(index_meta.ty), deserializers, None) } }; + let total_len = table.columns_len(); Ok(IndexIter { bounds: IterBounds::new(offset, limit_option), params: IndexImplParams { - index_meta, + index_meta: index_meta_ref, + meta_arena: arena.table_arena_cell().borrow(), table_name, deserializers, - total_len: table.columns_len(), + total_len, tx: self, cover_mapping, with_pk, @@ -274,6 +279,7 @@ pub trait Transaction: Sized { fn create_deserializers( columns: &[ColumnRef], table: &TableCatalog, + arena: &PlanArena, with_pk: bool, ) -> Vec { let mut pk_len = if with_pk { @@ -284,7 +290,8 @@ pub trait Transaction: Sized { let mut deserializers = Vec::with_capacity(table.columns_len()); let mut columns = columns.iter().peekable(); - for table_column in table.columns() { + for table_column_ref in table.columns() { + let table_column = arena.column(*table_column_ref); if columns.peek().is_none() && pk_len == 0 { break; } @@ -292,7 +299,7 @@ pub trait Transaction: Sized { let is_primary_key = with_pk && table_column.desc().primary().is_some(); if columns .peek() - .is_some_and(|column| same_projection_column(column, table_column)) + .is_some_and(|column| arena.same_column(**column, *table_column_ref)) { deserializers.push(table_column.datatype().serializable()); columns.next(); @@ -313,30 +320,29 @@ pub trait Transaction: Sized { fn add_index_meta( &mut self, - table_cache: &TableCache, + table_codec: &mut TableCodec, + plan_arena: &mut PlanArena, table_name: &TableName, index_name: String, column_ids: Vec, ty: IndexType, - ) -> Result { - if let Some(mut table) = self.table(table_cache, table_name.clone())?.cloned() { - let index_meta = table.add_index_meta(index_name, column_ids, ty)?; - let value = unsafe { &*self.table_codec() }.encode_index_meta_value(index_meta)?; - unsafe { &*self.table_codec() }.with_index_meta_key( - table_name, - index_meta.id, - |key| self.set(key, value.as_slice()), - )?; - table_cache.remove(table_name); + ) -> Result<(TableCatalog, IndexId), DatabaseError> { + let mut table = self + .load_table(table_codec, plan_arena, table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + let index_meta = table.add_index_meta(index_name, column_ids, ty, plan_arena)?; + let index_meta = plan_arena.index(index_meta); + let index_id = index_meta.id; + table_codec.with_index_meta(table_name, index_id, Some(index_meta), |key, value| { + self.set(key, value) + })?; - Ok(index_meta.id) - } else { - Err(DatabaseError::TableNotFound) - } + Ok((table, index_id)) } fn add_index( &mut self, + table_codec: &mut TableCodec, table_name: &str, index: Index, tuple_id: &TupleId, @@ -344,25 +350,23 @@ pub trait Transaction: Sized { if matches!(index.ty, IndexType::PrimaryKey { .. }) { return Ok(()); } - let mut value = BumpBytes::new_in(unsafe { &*self.table_codec() }.arena()); - bincode::serialize_into(&mut value, tuple_id)?; - - unsafe { &*self.table_codec() }.with_index_key(table_name, &index, Some(tuple_id), |key| { + table_codec.with_index(table_name, &index, Some(tuple_id), |key, value| { if matches!(index.ty, IndexType::Unique) { if let Some(bytes) = self.get_borrowed(key)? { - return if bytes.as_ref() != value.as_slice() { + return if bytes.as_ref() != value { Err(DatabaseError::DuplicateUniqueValue) } else { Ok(()) }; } } - self.set(key, value.as_slice()) + self.set(key, value) }) } fn del_index( &mut self, + table_codec: &mut TableCodec, table_name: &str, index: &Index, tuple_id: &TupleId, @@ -370,78 +374,84 @@ pub trait Transaction: Sized { if matches!(index.ty, IndexType::PrimaryKey { .. }) { return Ok(()); } - unsafe { &*self.table_codec() }.with_index_key( - table_name, - index, - Some(tuple_id), - |key| self.remove(key), - )?; + table_codec.with_index(table_name, index, Some(tuple_id), |key, _| self.remove(key))?; Ok(()) } - fn append_tuple( + fn append_tuple( &mut self, + table_codec: &mut TableCodec, table_name: &str, tuple: Tuple, - serializers: &[TupleValueSerializableImpl], + serializers: I, is_overwrite: bool, - ) -> Result<(), DatabaseError> { + ) -> Result<(), DatabaseError> + where + I: IntoIterator, + S: Borrow, + { let tuple_id = tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound)?; - let value = tuple.serialize_to(serializers, unsafe { &*self.table_codec() }.arena())?; - - unsafe { &*self.table_codec() }.with_tuple_key(table_name, tuple_id, |key| { - if !is_overwrite && self.exists(key)? { - return Err(DatabaseError::DuplicatePrimaryKey); - } - self.set(key, value.as_slice()) - }) + let mut serializers = serializers.into_iter(); + let mut write_value = + |tuple: &Tuple, value: &mut Bytes| tuple.serialize_to(&mut serializers, value); + table_codec.with_tuple( + table_name, + tuple_id, + Some((&tuple, &mut write_value)), + |key, value| { + if !is_overwrite && self.exists(key)? { + return Err(DatabaseError::DuplicatePrimaryKey); + } + self.set(key, value) + }, + ) } - fn remove_tuple(&mut self, table_name: &str, tuple_id: &TupleId) -> Result<(), DatabaseError> { - unsafe { &*self.table_codec() }.with_tuple_key(table_name, tuple_id, |key| self.remove(key)) + fn remove_tuple( + &mut self, + table_codec: &mut TableCodec, + table_name: &str, + tuple_id: &TupleId, + ) -> Result<(), DatabaseError> { + table_codec.with_tuple(table_name, tuple_id, None, |key, _| self.remove(key)) } fn rewrite_table_metadata( &mut self, - table_cache: &TableCache, + table_codec: &mut TableCodec, + arena: &impl MetaArena, table: &TableCatalog, ) -> Result<(), DatabaseError> { let table_name = table.name().clone(); - unsafe { &*self.table_codec() }.with_columns_bound(table_name.as_ref(), |min, max| { + table_codec.with_columns_bound(table_name.as_ref(), |min, max| { self.remove_range(Bound::Included(min), Bound::Included(max)) })?; - unsafe { &*self.table_codec() } - .with_index_meta_bound(table_name.as_ref(), |min, max| { - self.remove_range(Bound::Included(min), Bound::Included(max)) - })?; + table_codec.with_index_meta_bound(table_name.as_ref(), |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; - let mut reference_tables = ReferenceTables::new(); - let _ = reference_tables.push_or_replace(table.name()); - for column in table.columns() { - let value = unsafe { &*self.table_codec() } - .encode_column_value(column, &mut reference_tables)?; - unsafe { &*self.table_codec() } - .with_column_key(column, |key| self.set(key, value.as_slice()))?; + for column in table.columns().map(|column| arena.column(*column)) { + table_codec.with_column(column, true, |key, value| self.set(key, value))?; } for index_meta in table.indexes() { - let value = unsafe { &*self.table_codec() }.encode_index_meta_value(index_meta)?; - unsafe { &*self.table_codec() }.with_index_meta_key( - table.name(), + let index_meta = arena.index(*index_meta); + table_codec.with_index_meta( + table_name.as_ref(), index_meta.id, - |key| self.set(key, value.as_slice()), + Some(index_meta), + |key, value| self.set(key, value), )?; } - table_cache.remove(table.name()); - Ok(()) } #[allow(clippy::too_many_arguments)] fn change_column( &mut self, - table_cache: &TableCache, + table_codec: &mut TableCodec, + plan_arena: &mut PlanArena, table_name: &TableName, old_column_name: &str, new_column_name: &str, @@ -450,10 +460,9 @@ pub trait Transaction: Sized { not_null_change: &NotNullChange, ) -> Result { let table = self - .table(table_cache, table_name.clone())? - .cloned() + .load_table(table_codec, plan_arena, table_name.clone())? .ok_or(DatabaseError::TableNotFound)?; - let mut column_refs = Vec::with_capacity(table.columns_len()); + let mut column_catalogs = Vec::with_capacity(table.columns_len()); let mut found = false; if old_column_name != new_column_name && table.get_column_by_name(new_column_name).is_some() @@ -461,7 +470,7 @@ pub trait Transaction: Sized { return Err(DatabaseError::DuplicateColumn(new_column_name.to_string())); } - for column in table.columns() { + for column in table.columns().map(|column| plan_arena.column(*column)) { let mut new_column = ColumnCatalog::clone(column); if column.name() == old_column_name { found = true; @@ -473,6 +482,7 @@ pub trait Transaction: Sized { new_column.desc_mut().default = Some(ScalarExpression::type_cast( default_expr, Cow::Borrowed(new_data_type), + plan_arena, )?); } } @@ -493,308 +503,317 @@ pub trait Transaction: Sized { new_column.set_nullable(false); } } - column_refs.push(ColumnRef::from(new_column)); + column_catalogs.push(new_column); } if !found { return Err(DatabaseError::column_not_found(old_column_name.to_string())); } - let temp_table = TableCatalog::reload(table_name.clone(), column_refs.clone(), vec![])?; + let temp_table = TableCatalog::reload( + table_name.clone(), + column_catalogs.clone().into_iter(), + std::iter::empty(), + plan_arena, + )?; let index_metas = table .indexes() .map(|index_meta| { - Ok(Arc::new(IndexMeta { + let index_meta = plan_arena.index(*index_meta); + Ok(IndexMeta { id: index_meta.id, column_ids: index_meta.column_ids.clone(), table_name: table_name.clone(), pk_ty: temp_table.primary_keys_type().clone(), - value_ty: index_value_type(&temp_table, &index_meta.column_ids)?, + value_ty: index_value_type(&temp_table, plan_arena, &index_meta.column_ids)?, name: index_meta.name.clone(), ty: index_meta.ty, - })) + }) }) .collect::, DatabaseError>>()?; - let updated_table = TableCatalog::reload(table_name.clone(), column_refs, index_metas)?; - self.rewrite_table_metadata(table_cache, &updated_table)?; + let updated_table = TableCatalog::reload( + table_name.clone(), + column_catalogs.into_iter(), + index_metas.into_iter(), + plan_arena, + )?; + self.rewrite_table_metadata(table_codec, plan_arena, &updated_table)?; + table_codec.with_statistics_bound(table_name.as_ref(), |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; Ok(updated_table) } fn add_column( &mut self, - table_cache: &TableCache, + table_codec: &mut TableCodec, + plan_arena: &mut PlanArena, table_name: &TableName, column: &ColumnCatalog, if_not_exists: bool, - ) -> Result { - if let Some(mut table) = self.table(table_cache, table_name.clone())?.cloned() { - if !column.nullable() && column.default_value()?.is_none() { - return Err(DatabaseError::NeedNullAbleOrDefault); - } + ) -> Result<(TableCatalog, ColumnId), DatabaseError> { + let mut table = self + .load_table(table_codec, plan_arena, table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + if !column.nullable() && column.default_value()?.is_none() { + return Err(DatabaseError::NeedNullAbleOrDefault); + } - for col in table.columns() { - if col.name() == column.name() { - return if if_not_exists { - Ok(col.id().unwrap()) - } else { - Err(DatabaseError::DuplicateColumn(column.name().to_string())) - }; - } - } - let mut generator = Generator::new(); - let col_id = table.add_column(column.clone(), &mut generator)?; - - if column.desc().is_unique() { - let meta_ref = table.add_index_meta( - format!("uk_{}", column.name()), - vec![col_id], - IndexType::Unique, - )?; - let value = unsafe { &*self.table_codec() }.encode_index_meta_value(meta_ref)?; - unsafe { &*self.table_codec() }.with_index_meta_key( - table_name, - meta_ref.id, - |key| self.set(key, value.as_slice()), - )?; + for col in table.columns().map(|column| plan_arena.column(*column)) { + if col.name() == column.name() { + return if if_not_exists { + Ok((table, col.id().unwrap())) + } else { + Err(DatabaseError::DuplicateColumn(column.name().to_string())) + }; } + } + let mut generator = Generator::new(); + let col_id = table.add_column(column.clone(), &mut generator, plan_arena)?; + + if column.desc().is_unique() { + let meta_ref = table.add_index_meta( + format!("uk_{}", column.name()), + vec![col_id], + IndexType::Unique, + plan_arena, + )?; + let meta = plan_arena.index(meta_ref); + table_codec.with_index_meta(table_name, meta.id, Some(meta), |key, value| { + self.set(key, value) + })?; + } - let column = table.get_column_by_id(&col_id).unwrap(); - let value = unsafe { &*self.table_codec() } - .encode_column_value(column, &mut ReferenceTables::new())?; - unsafe { &*self.table_codec() } - .with_column_key(column, |key| self.set(key, value.as_slice()))?; - table_cache.remove(table_name); + let column = plan_arena.column(table.get_column_by_id(&col_id).unwrap()); + table_codec.with_column(column, true, |key, value| self.set(key, value))?; - Ok(col_id) - } else { - Err(DatabaseError::TableNotFound) - } + Ok((table, col_id)) } fn drop_column( &mut self, - table_cache: &TableCache, - meta_cache: &StatisticsMetaCache, + table_codec: &mut TableCodec, + plan_arena: &mut PlanArena, table_name: &TableName, column_name: &str, - ) -> Result<(), DatabaseError> { - if let Some(table_catalog) = self.table(table_cache, table_name.clone())?.cloned() { - let column = table_catalog.get_column_by_name(column_name).unwrap(); - - unsafe { &*self.table_codec() }.with_column_key(column, |key| self.remove(key))?; - - for index_meta in table_catalog.indexes.iter() { - if !index_meta.column_ids.contains(&column.id().unwrap()) { - continue; - } - unsafe { &*self.table_codec() }.with_index_meta_key( - table_name, - index_meta.id, - |key| self.remove(key), - )?; - - unsafe { &*self.table_codec() }.with_index_bound( - table_name, - index_meta.id, - |min, max| self.remove_range(Bound::Included(min), Bound::Included(max)), - )?; + ) -> Result { + let table_catalog = self + .load_table(table_codec, plan_arena, table_name.clone())? + .ok_or(DatabaseError::TableNotFound)?; + let column_id = { + let column = plan_arena.column(table_catalog.get_column_by_name(column_name).unwrap()); + let column_id = column.id().unwrap(); + table_codec.with_column(column, false, |key, _| self.remove(key))?; + column_id + }; - self.remove_statistics_meta(meta_cache, table_name, index_meta.id)?; + for index_meta in table_catalog.indexes.iter() { + let index_meta = plan_arena.index(*index_meta); + if !index_meta.column_ids.contains(&column_id) { + continue; } - table_cache.remove(table_name); + table_codec + .with_index_meta(table_name, index_meta.id, None, |key, _| self.remove(key))?; - Ok(()) - } else { - Err(DatabaseError::TableNotFound) + table_codec.with_index_bound(table_name, index_meta.id, |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; + + self.remove_statistics_meta(table_codec, table_name, index_meta.id)?; } + self.load_table(table_codec, plan_arena, table_name.clone())? + .ok_or(DatabaseError::TableNotFound) } fn create_view( &mut self, - view_cache: &ViewCache, + table_codec: &mut TableCodec, + arena: &PlanArena, view: View, or_replace: bool, - ) -> Result<(), DatabaseError> { - let value = unsafe { &*self.table_codec() }.encode_view_value(&view)?; - - let already_exists = - unsafe { &*self.table_codec() }.with_view_key(&view.name, |key| self.exists(key))?; + ) -> Result { + let already_exists = table_codec.with_view(&view.name, |key, _| self.exists(key))?; if !or_replace && already_exists { return Err(DatabaseError::ViewExists); } if !already_exists { - self.check_name_hash(&view.name)?; + self.check_name_hash(table_codec, &view.name)?; } - unsafe { &*self.table_codec() } - .with_view_key(&view.name, |key| self.set(key, value.as_slice()))?; - let _ = view_cache.put(view.name.clone(), view); + table_codec.with_view_value(&view.name, &view, arena, |key, value| self.set(key, value))?; - Ok(()) + Ok(view) } fn create_table( &mut self, - table_cache: &TableCache, + table_codec: &mut TableCodec, + plan_arena: &mut PlanArena, table_name: TableName, columns: Vec, if_not_exists: bool, - ) -> Result { - let mut table_catalog = TableCatalog::new(table_name.clone(), columns)?; + ) -> Result, DatabaseError> { + let mut table_catalog = TableCatalog::new(table_name.clone(), columns, plan_arena)?; for (_, column) in table_catalog.primary_keys() { - TableCodec::check_primary_key_type(column.datatype())?; + TableCodec::check_primary_key_type(plan_arena.column(*column).datatype())?; } - let value = unsafe { &*self.table_codec() } - .encode_root_table_value(&TableMeta::empty(table_name.clone()))?; - let exists = unsafe { &*self.table_codec() } - .with_root_table_key(table_name.as_ref(), |key| self.exists(key))?; + let exists = + table_codec.with_root_table(table_name.as_ref(), None, |key, _| self.exists(key))?; if exists { if if_not_exists { - return Ok(table_name); + return Ok(None); } return Err(DatabaseError::TableExists); } - self.check_name_hash(&table_name)?; - self.create_index_meta_from_column(&mut table_catalog)?; - unsafe { &*self.table_codec() } - .with_root_table_key(table_name.as_ref(), |key| self.set(key, value.as_slice()))?; - - let mut reference_tables = ReferenceTables::new(); - for column in table_catalog.columns() { - let value = unsafe { &*self.table_codec() } - .encode_column_value(column, &mut reference_tables)?; - unsafe { &*self.table_codec() } - .with_column_key(column, |key| self.set(key, value.as_slice()))?; + self.check_name_hash(table_codec, &table_name)?; + self.create_index_meta_from_column(table_codec, plan_arena, &mut table_catalog)?; + let table_meta = TableMeta::empty(table_name.clone()); + table_codec.with_root_table(table_name.as_ref(), Some(&table_meta), |key, value| { + self.set(key, value) + })?; + + for column in table_catalog + .columns() + .map(|column| plan_arena.column(*column)) + { + table_codec.with_column(column, true, |key, value| self.set(key, value))?; } - debug_assert_eq!(reference_tables.len(), 1); - table_cache.put(table_name.clone(), table_catalog); - Ok(table_name) + Ok(Some(table_catalog)) } - fn check_name_hash(&mut self, table_name: &TableName) -> Result<(), DatabaseError> { - if unsafe { &*self.table_codec() } - .with_table_hash_key(table_name, |key| self.exists(key))? - { + fn check_name_hash( + &mut self, + table_codec: &mut TableCodec, + table_name: &TableName, + ) -> Result<(), DatabaseError> { + if table_codec.with_table_hash(table_name, |key, _| self.exists(key))? { return Err(DatabaseError::DuplicateSourceHash(table_name.to_string())); } - unsafe { &*self.table_codec() }.with_table_hash_key(table_name, |key| self.set(key, &[])) + table_codec.with_table_hash(table_name, |key, _| self.set(key, &[])) } - fn drop_name_hash(&mut self, table_name: &TableName) -> Result<(), DatabaseError> { - unsafe { &*self.table_codec() }.with_table_hash_key(table_name, |key| self.remove(key)) + fn drop_name_hash( + &mut self, + table_codec: &mut TableCodec, + table_name: &TableName, + ) -> Result<(), DatabaseError> { + table_codec.with_table_hash(table_name, |key, _| self.remove(key)) } fn drop_view( &mut self, - view_cache: &ViewCache, + table_codec: &mut TableCodec, view_name: TableName, if_exists: bool, - ) -> Result<(), DatabaseError> { - self.drop_name_hash(&view_name)?; - let exists = - unsafe { &*self.table_codec() }.with_view_key(&view_name, |key| self.exists(key))?; + ) -> Result { + self.drop_name_hash(table_codec, &view_name)?; + let exists = table_codec.with_view(&view_name, |key, _| self.exists(key))?; if !exists { if if_exists { - return Ok(()); + return Ok(false); } else { return Err(DatabaseError::ViewNotFound); } } - unsafe { &*self.table_codec() } - .with_view_key(view_name.as_ref(), |key| self.remove(key))?; - view_cache.remove(&view_name); + table_codec.with_view(view_name.as_ref(), |key, _| self.remove(key))?; - Ok(()) + Ok(true) } fn drop_index( &mut self, - table_cache: &TableCache, - meta_cache: &StatisticsMetaCache, + table_codec: &mut TableCodec, + plan_arena: &mut PlanArena, table_name: TableName, index_name: &str, if_exists: bool, - ) -> Result<(), DatabaseError> { + ) -> Result, DatabaseError> { let table = self - .table(table_cache, table_name.clone())? + .load_table(table_codec, plan_arena, table_name.clone())? .ok_or(DatabaseError::TableNotFound)?; - let Some(index_meta) = table.indexes.iter().find(|index| index.name == index_name) else { + let Some(index_meta_ref) = table + .indexes + .iter() + .copied() + .find(|index| plan_arena.index(*index).name == index_name) + else { if if_exists { - return Ok(()); + return Ok(None); } else { return Err(DatabaseError::TableNotFound); } }; + let index_meta = plan_arena.index(index_meta_ref); match index_meta.ty { IndexType::PrimaryKey { .. } => return Err(DatabaseError::InvalidIndex), IndexType::Unique | IndexType::Normal | IndexType::Composite => (), } let index_id = index_meta.id; - unsafe { &*self.table_codec() }.with_index_meta_key( - table_name.as_ref(), - index_id, - |key| self.remove(key), - )?; - - unsafe { &*self.table_codec() }.with_index_bound( - table_name.as_ref(), - index_id, - |min, max| self.remove_range(Bound::Included(min), Bound::Included(max)), - )?; + table_codec.with_index_meta(table_name.as_ref(), index_id, None, |key, _| { + self.remove(key) + })?; - self.remove_statistics_meta(meta_cache, &table_name, index_id)?; + table_codec.with_index_bound(table_name.as_ref(), index_id, |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; - table_cache.remove(&table_name); + self.remove_statistics_meta(table_codec, &table_name, index_id)?; - Ok(()) + self.load_table(table_codec, plan_arena, table_name.clone())? + .map(|table| (table, index_id)) + .map(Some) + .ok_or(DatabaseError::TableNotFound) } fn drop_table( &mut self, - table_cache: &TableCache, + table_codec: &mut TableCodec, table_name: TableName, if_exists: bool, - ) -> Result<(), DatabaseError> { - if self.table(table_cache, table_name.clone())?.is_none() { + ) -> Result { + let exists = + table_codec.with_root_table(table_name.as_ref(), None, |key, _| self.exists(key))?; + if !exists { if if_exists { - return Ok(()); + return Ok(false); } else { return Err(DatabaseError::TableNotFound); } } - self.drop_name_hash(&table_name)?; - self.drop_data(table_name.as_ref())?; + self.drop_name_hash(table_codec, &table_name)?; + self.drop_data(table_codec, table_name.as_ref())?; - unsafe { &*self.table_codec() }.with_columns_bound(table_name.as_ref(), |min, max| { + table_codec.with_columns_bound(table_name.as_ref(), |min, max| { self.remove_range(Bound::Included(min), Bound::Included(max)) })?; - unsafe { &*self.table_codec() } - .with_index_meta_bound(table_name.as_ref(), |min, max| { - self.remove_range(Bound::Included(min), Bound::Included(max)) - })?; + table_codec.with_index_meta_bound(table_name.as_ref(), |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; - unsafe { &*self.table_codec() } - .with_root_table_key(table_name.as_ref(), |key| self.remove(key))?; - table_cache.remove(&table_name); + table_codec.with_root_table(table_name.as_ref(), None, |key, _| self.remove(key))?; - Ok(()) + Ok(true) } - fn drop_data(&mut self, table_name: &str) -> Result<(), DatabaseError> { - unsafe { &*self.table_codec() }.with_tuple_bound(table_name, |min, max| { + fn drop_data( + &mut self, + table_codec: &mut TableCodec, + table_name: &str, + ) -> Result<(), DatabaseError> { + table_codec.with_tuple_bound(table_name, |min, max| { self.remove_range(Bound::Included(min), Bound::Included(max)) })?; - unsafe { &*self.table_codec() }.with_all_index_bound(table_name, |min, max| { + table_codec.with_all_index_bound(table_name, |min, max| { self.remove_range(Bound::Included(min), Bound::Included(max)) })?; - unsafe { &*self.table_codec() }.with_statistics_bound(table_name, |min, max| { + table_codec.with_statistics_bound(table_name, |min, max| { self.remove_range(Bound::Included(min), Bound::Included(max)) }) } @@ -807,35 +826,48 @@ pub trait Transaction: Sized { table_functions: &'a TableFunctions, view_name: TableName, ) -> Result, DatabaseError> { - if let Some(view) = view_cache.get(&view_name) { - return Ok(Some(view)); - } - unsafe { &*self.table_codec() }.with_view_key(&view_name, |key| { + let _ = (table_cache, scala_functions, table_functions); + Ok(view_cache.get(&view_name)) + } + + fn load_view( + &self, + table_codec: &mut TableCodec, + table_cache: &TableCache, + table_arena: &TableArenaCell, + scala_functions: &ScalaFunctions, + table_functions: &TableFunctions, + view_name: TableName, + ) -> Result, DatabaseError> { + table_codec.with_view(&view_name, |key, _| { let Some(bytes) = self.get_borrowed(key)? else { return Ok(None); }; - Ok(Some(view_cache.get_or_insert(view_name.clone(), |_| { - TableCodec::decode_view( - bytes.as_ref(), - (self, table_cache), - scala_functions, - table_functions, - ) - })?)) + TableCodec::decode_view( + bytes.as_ref(), + (self, table_cache), + scala_functions, + table_functions, + table_arena.borrow_mut(), + ) + .map(Some) }) } fn views<'a>( &'a self, + table_codec: &mut TableCodec, table_cache: &'a TableCache, + table_arena: &'a TableArenaCell, scala_functions: &'a ScalaFunctions, table_functions: &'a TableFunctions, ) -> Result, DatabaseError> { - unsafe { &*self.table_codec() }.with_view_bound(|min, max| { + table_codec.with_view_bound(|min, max| { Ok(ViewIter { iter: self.range(Bound::Included(min), Bound::Included(max))?, transaction: self, table_cache, + table_arena, scala_functions, table_functions, }) @@ -843,26 +875,31 @@ pub trait Transaction: Sized { } fn table<'a>( - &'a self, + &self, table_cache: &'a TableCache, table_name: TableName, ) -> Result, DatabaseError> { - if let Some(table) = table_cache.get(&table_name) { - return Ok(Some(table)); - } + Ok(table_cache.get(&table_name)) + } - // `TableCache` is not theoretically used in `table_collect` because ColumnCatalog should not depend on other Column - self.table_collect(&table_name)? + fn load_table( + &self, + table_codec: &mut TableCodec, + arena: &mut impl MetaArena, + table_name: TableName, + ) -> Result, DatabaseError> { + self.table_collect(table_codec, &table_name)? .map(|(columns, indexes)| { - table_cache.get_or_insert(table_name.clone(), |_| { - TableCatalog::reload(table_name, columns, indexes) - }) + TableCatalog::reload(table_name, columns.into_iter(), indexes.into_iter(), arena) }) .transpose() } - fn tables<'a>(&'a self) -> Result, DatabaseError> { - unsafe { &*self.table_codec() }.with_root_table_bound(|min, max| { + fn tables<'a>( + &'a self, + table_codec: &mut TableCodec, + ) -> Result, DatabaseError> { + table_codec.with_root_table_bound(|min, max| { Ok(TableIter { iter: self.range(Bound::Included(min), Bound::Included(max))?, }) @@ -871,135 +908,114 @@ pub trait Transaction: Sized { fn save_statistics_meta( &mut self, - meta_cache: &StatisticsMetaCache, + table_codec: &mut TableCodec, table_name: &TableName, statistics_meta: StatisticsMeta, ) -> Result<(), DatabaseError> { let index_id = statistics_meta.index_id(); - let cached_meta = statistics_meta.clone(); let (root, buckets, cm_sketch) = statistics_meta.into_parts(); - let value = unsafe { &*self.table_codec() }.encode_statistics_meta_value(&root)?; - unsafe { &*self.table_codec() }.with_statistics_meta_key( + table_codec.with_statistics_meta( table_name.as_ref(), index_id, - |key| self.set(key, value.as_slice()), + Some(&root), + |key, value| self.set(key, value), )?; let (sketch_meta, sketch_pages) = cm_sketch.into_storage_parts(COUNT_MIN_SKETCH_STORAGE_PAGE_LEN); - let value = - unsafe { &*self.table_codec() }.encode_statistics_sketch_meta_value(&sketch_meta)?; - unsafe { &*self.table_codec() }.with_statistics_sketch_meta_key( + table_codec.with_statistics_sketch_meta( table_name.as_ref(), index_id, - |key| self.set(key, value.as_slice()), + Some(&sketch_meta), + |key, value| self.set(key, value), )?; for sketch_page in sketch_pages { - let value = unsafe { &*self.table_codec() } - .encode_statistics_sketch_page_value(&sketch_page)?; - unsafe { &*self.table_codec() }.with_statistics_sketch_page_key( + table_codec.with_statistics_sketch_page( table_name.as_ref(), index_id, &sketch_page, - |key| self.set(key, value.as_slice()), + true, + |key, value| self.set(key, value), )?; } for (ordinal, bucket) in buckets.iter().enumerate() { - let value = unsafe { &*self.table_codec() }.encode_statistics_bucket_value(bucket)?; - unsafe { &*self.table_codec() }.with_statistics_bucket_key( + table_codec.with_statistics_bucket( table_name.as_ref(), index_id, ordinal as u32, - |key| self.set(key, value.as_slice()), + Some(bucket), + |key, value| self.set(key, value), )?; } - meta_cache.put((table_name.clone(), index_id), Some(cached_meta)); - Ok(()) } fn statistics_meta( &self, + table_codec: &mut TableCodec, table_name: &str, index_id: IndexId, ) -> Result, DatabaseError> { - unsafe { &*self.table_codec() }.with_statistics_index_bound( - table_name, - index_id, - |min, max| { - let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; - let mut root = None; - let mut buckets = Vec::new(); - let mut sketch_meta = None; - let mut sketch_pages = Vec::::new(); - - while let Some((key, value)) = iter.try_next()? { - match unsafe { &*self.table_codec() }.decode_statistics_codec_type(key)? { - StatisticsCodecType::Root => { - root = Some(TableCodec::decode_statistics_meta::(value)?); - } - StatisticsCodecType::SketchMeta => { - sketch_meta = - Some(TableCodec::decode_statistics_sketch_meta::(value)?); - } - StatisticsCodecType::SketchPage => { - sketch_pages - .push(TableCodec::decode_statistics_sketch_page::(value)?); - } - StatisticsCodecType::Bucket => { - buckets.push(TableCodec::decode_statistics_bucket::(value)?); - } + table_codec.with_statistics_index_bound(table_name, index_id, |min, max| { + let mut iter = self.range(Bound::Included(min), Bound::Included(max))?; + let mut root = None; + let mut buckets = Vec::new(); + let mut sketch_meta = None; + let mut sketch_pages = Vec::::new(); + + while let Some((key, value)) = iter.try_next()? { + match TableCodec::decode_statistics_codec_type(key)? { + StatisticsCodecType::Root => { + root = Some(TableCodec::decode_statistics_meta::(value)?); + } + StatisticsCodecType::SketchMeta => { + sketch_meta = + Some(TableCodec::decode_statistics_sketch_meta::(value)?); + } + StatisticsCodecType::SketchPage => { + sketch_pages + .push(TableCodec::decode_statistics_sketch_page::(value)?); + } + StatisticsCodecType::Bucket => { + buckets.push(TableCodec::decode_statistics_bucket::(value)?); } } + } - match (root, sketch_meta) { - (Some(root), Some(sketch_meta)) => { - let sketch = CountMinSketch::from_storage_parts(sketch_meta, sketch_pages)?; - StatisticsMeta::from_parts(root, buckets, sketch).map(Some) - } - (None, None) if buckets.is_empty() && sketch_pages.is_empty() => Ok(None), - _ => Err(DatabaseError::InvalidValue( - "statistics meta is incomplete".to_string(), - )), + match (root, sketch_meta) { + (Some(root), Some(sketch_meta)) => { + let sketch = CountMinSketch::from_storage_parts(sketch_meta, sketch_pages)?; + StatisticsMeta::from_parts(root, buckets, sketch).map(Some) } - }, - ) + (None, None) if buckets.is_empty() && sketch_pages.is_empty() => Ok(None), + _ => Err(DatabaseError::InvalidValue( + "statistics meta is incomplete".to_string(), + )), + } + }) } fn remove_statistics_meta( &mut self, - meta_cache: &StatisticsMetaCache, + table_codec: &mut TableCodec, table_name: &TableName, index_id: IndexId, ) -> Result<(), DatabaseError> { - unsafe { &*self.table_codec() }.with_statistics_index_bound( - table_name.as_ref(), - index_id, - |min, max| self.remove_range(Bound::Included(min), Bound::Included(max)), - )?; - - meta_cache.remove(&(table_name.clone(), index_id)); + table_codec.with_statistics_index_bound(table_name.as_ref(), index_id, |min, max| { + self.remove_range(Bound::Included(min), Bound::Included(max)) + })?; Ok(()) } - fn meta_loader<'a>( - &'a self, - meta_cache: &'a StatisticsMetaCache, - ) -> StatisticMetaLoader<'a, Self> - where - Self: Sized, - { - StatisticMetaLoader::new(self, meta_cache) - } - #[allow(clippy::type_complexity)] fn table_collect( &self, + table_codec: &mut TableCodec, table_name: &TableName, - ) -> Result, Vec)>, DatabaseError> { - unsafe { &*self.table_codec() }.with_table_bound(table_name, |table_min, table_max| { + ) -> Result, Vec)>, DatabaseError> { + table_codec.with_table_bound(table_name, |table_min, table_max| { let mut column_iter = self.range(Bound::Included(table_min), Bound::Included(table_max))?; @@ -1017,7 +1033,7 @@ pub trait Transaction: Sized { &reference_tables, )?); } else { - index_metas.push(Arc::new(TableCodec::decode_index_meta::(value)?)); + index_metas.push(TableCodec::decode_index_meta::(value)?); } } @@ -1027,43 +1043,47 @@ pub trait Transaction: Sized { fn create_index_meta_from_column( &mut self, + table_codec: &mut TableCodec, + arena: &mut impl MetaArena, table: &mut TableCatalog, ) -> Result<(), DatabaseError> { let table_name = table.name.clone(); let mut primary_keys = Vec::new(); - let schema_ref = table.schema_ref().clone(); - for col in schema_ref.iter() { + let mut i = 0; + while i < table.columns_len() { + let column = table.column_ref(i).unwrap(); + i += 1; + let col = arena.column(column); let col_id = col.id().ok_or(DatabaseError::PrimaryKeyNotFound)?; - let index_ty = if let Some(i) = col.desc().primary() { - primary_keys.push((i, col_id)); + let index_ty = if let Some(primary_key_index) = col.desc().primary() { + primary_keys.push((primary_key_index, col_id)); continue; } else if col.desc().is_unique() { IndexType::Unique } else { continue; }; - let meta_ref = - table.add_index_meta(format!("uk_{}_index", col.name()), vec![col_id], index_ty)?; - let value = unsafe { &*self.table_codec() }.encode_index_meta_value(meta_ref)?; - unsafe { &*self.table_codec() }.with_index_meta_key( - &table_name, - meta_ref.id, - |key| self.set(key, value.as_slice()), - )?; + let index_name = format!("uk_{}_index", col.name()); + let meta_ref = table.add_index_meta(index_name, vec![col_id], index_ty, arena)?; + let meta = arena.index(meta_ref); + table_codec.with_index_meta(&table_name, meta.id, Some(meta), |key, value| { + self.set(key, value) + })?; } let primary_keys = table .primary_keys() .iter() - .map(|(_, column)| column.id().unwrap()) + .map(|(_, column)| arena.column(*column).id().unwrap()) .collect_vec(); let pk_index_ty = IndexType::PrimaryKey { is_multiple: primary_keys.len() != 1, }; - let meta_ref = table.add_index_meta("pk_index".to_string(), primary_keys, pk_index_ty)?; - let value = unsafe { &*self.table_codec() }.encode_index_meta_value(meta_ref)?; - unsafe { &*self.table_codec() }.with_index_meta_key(&table_name, meta_ref.id, |key| { - self.set(key, value.as_slice()) + let meta_ref = + table.add_index_meta("pk_index".to_string(), primary_keys, pk_index_ty, arena)?; + let meta = arena.index(meta_ref); + table_codec.with_index_meta(&table_name, meta.id, Some(meta), |key, value| { + self.set(key, value) })?; Ok(()) @@ -1128,7 +1148,7 @@ pub trait Transaction: Sized { fn commit(self) -> Result<(), DatabaseError>; } -fn owned_bound(bound: Bound<&[u8]>) -> Bound { +pub(crate) fn owned_bound(bound: Bound<&[u8]>) -> Bound { match bound { Bound::Included(bytes) => Bound::Included(bytes.to_vec()), Bound::Excluded(bytes) => Bound::Excluded(bytes.to_vec()), @@ -1146,14 +1166,7 @@ pub(crate) fn reuse_bound_as_excluded(bound: &mut Bound, key: &[u8]) { *bound = Bound::Excluded(bytes); } -fn same_projection_column(left: &ColumnRef, right: &ColumnRef) -> bool { - match (left.id(), right.id()) { - (Some(left), Some(right)) => left == right, - _ => left.name() == right.name(), - } -} - -fn bytes_bound_as_slice(bound: &Bound) -> Bound<&[u8]> { +pub(crate) fn bytes_bound_as_slice(bound: &Bound) -> Bound<&[u8]> { match bound { Bound::Included(bytes) => Bound::Included(bytes.as_slice()), Bound::Excluded(bytes) => Bound::Excluded(bytes.as_slice()), @@ -1181,6 +1194,7 @@ fn encode_bound_key(buffer: &mut Bytes, key: &[u8], is_upper: bool) { #[inline] fn encode_bound<'a>( + table_codec: &mut TableCodec, bound: &Bound, is_upper: bool, buffer: &'a mut Bytes, @@ -1189,11 +1203,11 @@ fn encode_bound<'a>( ) -> Result, DatabaseError> { match bound { Bound::Included(val) => { - inner.bound_key(params, val, is_upper, buffer)?; + inner.bound_key(table_codec, params, val, is_upper, buffer)?; Ok(Bound::Included(buffer.as_slice())) } Bound::Excluded(val) => { - inner.bound_key(params, val, is_upper, buffer)?; + inner.bound_key(table_codec, params, val, is_upper, buffer)?; Ok(Bound::Excluded(buffer.as_slice())) } Bound::Unbounded => Ok(Bound::Unbounded), @@ -1203,6 +1217,7 @@ fn encode_bound<'a>( trait IndexImpl { fn index_lookup_into( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, key: &[u8], value: &[u8], @@ -1211,6 +1226,7 @@ trait IndexImpl { fn eq_to_res<'a>( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, @@ -1220,6 +1236,7 @@ trait IndexImpl { fn bound_key( &self, + table_codec: &mut TableCodec, params: &IndexImplParams, value: &DataValue, is_upper: bool, @@ -1277,7 +1294,8 @@ impl TupleMapping { struct IndexImplParams<'a, T: Transaction> { index_meta: IndexMetaRef, - table_name: &'a str, + meta_arena: &'a dyn MetaArena, + table_name: TableName, deserializers: Vec, total_len: usize, tx: &'a T, @@ -1287,21 +1305,23 @@ struct IndexImplParams<'a, T: Transaction> { impl IndexImplParams<'_, T> { #[inline] - pub(crate) fn value_ty(&self) -> &LogicalType { - &self.index_meta.value_ty + pub(crate) fn index_meta(&self) -> &IndexMeta { + self.meta_arena.index(self.index_meta) } #[inline] - pub(crate) fn table_codec(&self) -> *const TableCodec { - self.tx.table_codec() + pub(crate) fn value_ty(&self) -> &LogicalType { + &self.index_meta().value_ty } + #[inline] fn get_tuple_by_id_into( &self, + table_codec: &mut TableCodec, tuple_id: &TupleId, tuple: &mut Tuple, ) -> Result { - unsafe { &*self.table_codec() }.with_tuple_key_unchecked(self.table_name, tuple_id, |key| { + table_codec.with_tuple_unchecked(self.table_name.as_ref(), tuple_id, None, |key, _| { let Some(bytes) = self.tx.get_borrowed(key)? else { return Ok(false); }; @@ -1326,22 +1346,34 @@ enum IndexResult<'a, T: Transaction + 'a> { impl IndexImpl for IndexImplEnum { fn index_lookup_into( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, key: &[u8], value: &[u8], params: &IndexImplParams, ) -> Result<(), DatabaseError> { match self { - IndexImplEnum::PrimaryKey(inner) => inner.index_lookup_into(tuple, key, value, params), - IndexImplEnum::Unique(inner) => inner.index_lookup_into(tuple, key, value, params), - IndexImplEnum::Normal(inner) => inner.index_lookup_into(tuple, key, value, params), - IndexImplEnum::Composite(inner) => inner.index_lookup_into(tuple, key, value, params), - IndexImplEnum::Covered(inner) => inner.index_lookup_into(tuple, key, value, params), + IndexImplEnum::PrimaryKey(inner) => { + inner.index_lookup_into(table_codec, tuple, key, value, params) + } + IndexImplEnum::Unique(inner) => { + inner.index_lookup_into(table_codec, tuple, key, value, params) + } + IndexImplEnum::Normal(inner) => { + inner.index_lookup_into(table_codec, tuple, key, value, params) + } + IndexImplEnum::Composite(inner) => { + inner.index_lookup_into(table_codec, tuple, key, value, params) + } + IndexImplEnum::Covered(inner) => { + inner.index_lookup_into(table_codec, tuple, key, value, params) + } } } fn eq_to_res<'a>( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, @@ -1350,36 +1382,47 @@ impl IndexImpl for IndexImplEnum { ) -> Result, DatabaseError> { match self { IndexImplEnum::PrimaryKey(inner) => { - inner.eq_to_res(tuple, value, params, encode_min, encode_max) + inner.eq_to_res(table_codec, tuple, value, params, encode_min, encode_max) } IndexImplEnum::Unique(inner) => { - inner.eq_to_res(tuple, value, params, encode_min, encode_max) + inner.eq_to_res(table_codec, tuple, value, params, encode_min, encode_max) } IndexImplEnum::Normal(inner) => { - inner.eq_to_res(tuple, value, params, encode_min, encode_max) + inner.eq_to_res(table_codec, tuple, value, params, encode_min, encode_max) } IndexImplEnum::Composite(inner) => { - inner.eq_to_res(tuple, value, params, encode_min, encode_max) + inner.eq_to_res(table_codec, tuple, value, params, encode_min, encode_max) } IndexImplEnum::Covered(inner) => { - inner.eq_to_res(tuple, value, params, encode_min, encode_max) + inner.eq_to_res(table_codec, tuple, value, params, encode_min, encode_max) } } } fn bound_key( &self, + table_codec: &mut TableCodec, params: &IndexImplParams, value: &DataValue, is_upper: bool, out: &mut Bytes, ) -> Result<(), DatabaseError> { match self { - IndexImplEnum::PrimaryKey(inner) => inner.bound_key(params, value, is_upper, out), - IndexImplEnum::Unique(inner) => inner.bound_key(params, value, is_upper, out), - IndexImplEnum::Normal(inner) => inner.bound_key(params, value, is_upper, out), - IndexImplEnum::Composite(inner) => inner.bound_key(params, value, is_upper, out), - IndexImplEnum::Covered(inner) => inner.bound_key(params, value, is_upper, out), + IndexImplEnum::PrimaryKey(inner) => { + inner.bound_key(table_codec, params, value, is_upper, out) + } + IndexImplEnum::Unique(inner) => { + inner.bound_key(table_codec, params, value, is_upper, out) + } + IndexImplEnum::Normal(inner) => { + inner.bound_key(table_codec, params, value, is_upper, out) + } + IndexImplEnum::Composite(inner) => { + inner.bound_key(table_codec, params, value, is_upper, out) + } + IndexImplEnum::Covered(inner) => { + inner.bound_key(table_codec, params, value, is_upper, out) + } } } } @@ -1387,12 +1430,13 @@ impl IndexImpl for IndexImplEnum { impl IndexImpl for PrimaryKeyIndexImpl { fn index_lookup_into( &self, + _: &mut TableCodec, tuple: &mut Tuple, key: &[u8], value: &[u8], params: &IndexImplParams, ) -> Result<(), DatabaseError> { - let tuple_id = TableCodec::decode_tuple_key(key, ¶ms.index_meta.pk_ty)?; + let tuple_id = TableCodec::decode_tuple_key(key, ¶ms.index_meta().pk_ty)?; TableCodec::decode_tuple_into( tuple, ¶ms.deserializers, @@ -1404,6 +1448,7 @@ impl IndexImpl for PrimaryKeyIndexImpl { fn eq_to_res<'a>( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, @@ -1411,10 +1456,11 @@ impl IndexImpl for PrimaryKeyIndexImpl { _: &mut Bytes, ) -> Result, DatabaseError> { let tuple_id = value.clone(); - let found = unsafe { &*params.table_codec() }.with_tuple_key_unchecked( - params.table_name, + let found = table_codec.with_tuple_unchecked( + params.table_name.as_ref(), value, - |key| { + None, + |key, _| { let Some(bytes) = params.tx.get_borrowed(key)? else { return Ok(false); }; @@ -1437,31 +1483,29 @@ impl IndexImpl for PrimaryKeyIndexImpl { fn bound_key( &self, + table_codec: &mut TableCodec, params: &IndexImplParams, value: &DataValue, _: bool, out: &mut Bytes, ) -> Result<(), DatabaseError> { - unsafe { &*params.table_codec() }.with_tuple_key_unchecked( - params.table_name, - value, - |key| { - out.clear(); - out.extend_from_slice(key); - Ok(()) - }, - ) + table_codec.with_tuple_unchecked(params.table_name.as_ref(), value, None, |key, _| { + out.clear(); + out.extend_from_slice(key); + Ok(()) + }) } } #[inline(always)] fn secondary_index_lookup( + table_codec: &mut TableCodec, tuple: &mut Tuple, bytes: &[u8], params: &IndexImplParams, ) -> Result<(), DatabaseError> { let tuple_id = TableCodec::decode_index(bytes)?; - if params.get_tuple_by_id_into(&tuple_id, tuple)? { + if params.get_tuple_by_id_into(table_codec, &tuple_id, tuple)? { Ok(()) } else { Err(DatabaseError::TupleIdNotFound(tuple_id)) @@ -1471,34 +1515,34 @@ fn secondary_index_lookup( impl IndexImpl for UniqueIndexImpl { fn index_lookup_into( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, _: &[u8], value: &[u8], params: &IndexImplParams, ) -> Result<(), DatabaseError> { - secondary_index_lookup(tuple, value, params) + secondary_index_lookup(table_codec, tuple, value, params) } fn eq_to_res<'a>( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, _: &mut Bytes, _: &mut Bytes, ) -> Result, DatabaseError> { - let index = Index::new(params.index_meta.id, value, IndexType::Unique); - let Some(bytes) = unsafe { &*params.table_codec() }.with_index_key( - params.table_name, - &index, - None, - |key| params.tx.get_borrowed(key), - )? + let index = Index::new(params.index_meta().id, value, IndexType::Unique); + let Some(bytes) = + table_codec.with_index(params.table_name.as_ref(), &index, None, |key, _| { + params.tx.get_borrowed(key) + })? else { return Ok(IndexResult::Miss); }; let tuple_id = TableCodec::decode_index(bytes.as_ref())?; - if params.get_tuple_by_id_into(&tuple_id, tuple)? { + if params.get_tuple_by_id_into(table_codec, &tuple_id, tuple)? { Ok(IndexResult::Hit) } else { Err(DatabaseError::TupleIdNotFound(tuple_id)) @@ -1507,14 +1551,15 @@ impl IndexImpl for UniqueIndexImpl { fn bound_key( &self, + table_codec: &mut TableCodec, params: &IndexImplParams, value: &DataValue, _: bool, out: &mut Bytes, ) -> Result<(), DatabaseError> { - let index = Index::new(params.index_meta.id, value, IndexType::Unique); + let index = Index::new(params.index_meta().id, value, IndexType::Unique); - unsafe { &*params.table_codec() }.with_index_key(params.table_name, &index, None, |key| { + table_codec.with_index(params.table_name.as_ref(), &index, None, |key, _| { out.clear(); out.extend_from_slice(key); Ok(()) @@ -1525,34 +1570,45 @@ impl IndexImpl for UniqueIndexImpl { impl IndexImpl for NormalIndexImpl { fn index_lookup_into( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, _: &[u8], value: &[u8], params: &IndexImplParams, ) -> Result<(), DatabaseError> { - secondary_index_lookup(tuple, value, params) + secondary_index_lookup(table_codec, tuple, value, params) } fn eq_to_res<'a>( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, encode_min: &mut Bytes, encode_max: &mut Bytes, ) -> Result, DatabaseError> { - eq_to_res_scope(tuple, self, value, params, encode_min, encode_max) + eq_to_res_scope( + table_codec, + tuple, + self, + value, + params, + encode_min, + encode_max, + ) } fn bound_key( &self, + table_codec: &mut TableCodec, params: &IndexImplParams, value: &DataValue, is_upper: bool, out: &mut Bytes, ) -> Result<(), DatabaseError> { - let index = Index::new(params.index_meta.id, value, IndexType::Normal); - unsafe { &*params.table_codec() }.with_index_key(params.table_name, &index, None, |key| { + let index = Index::new(params.index_meta().id, value, IndexType::Normal); + table_codec.with_index(params.table_name.as_ref(), &index, None, |key, _| { encode_bound_key(out, key, is_upper); Ok(()) }) @@ -1562,34 +1618,45 @@ impl IndexImpl for NormalIndexImpl { impl IndexImpl for CompositeIndexImpl { fn index_lookup_into( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, _: &[u8], value: &[u8], params: &IndexImplParams, ) -> Result<(), DatabaseError> { - secondary_index_lookup(tuple, value, params) + secondary_index_lookup(table_codec, tuple, value, params) } fn eq_to_res<'a>( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, encode_min: &mut Bytes, encode_max: &mut Bytes, ) -> Result, DatabaseError> { - eq_to_res_scope(tuple, self, value, params, encode_min, encode_max) + eq_to_res_scope( + table_codec, + tuple, + self, + value, + params, + encode_min, + encode_max, + ) } fn bound_key( &self, + table_codec: &mut TableCodec, params: &IndexImplParams, value: &DataValue, is_upper: bool, out: &mut Bytes, ) -> Result<(), DatabaseError> { - let index = Index::new(params.index_meta.id, value, IndexType::Composite); - unsafe { &*params.table_codec() }.with_index_key(params.table_name, &index, None, |key| { + let index = Index::new(params.index_meta().id, value, IndexType::Composite); + table_codec.with_index(params.table_name.as_ref(), &index, None, |key, _| { encode_bound_key(out, key, is_upper); Ok(()) }) @@ -1599,6 +1666,7 @@ impl IndexImpl for CompositeIndexImpl { impl IndexImpl for CoveredIndexImpl { fn index_lookup_into( &self, + _: &mut TableCodec, tuple: &mut Tuple, key: &[u8], value: &[u8], @@ -1625,24 +1693,34 @@ impl IndexImpl for CoveredIndexImpl { fn eq_to_res<'a>( &self, + table_codec: &mut TableCodec, tuple: &mut Tuple, value: &DataValue, params: &IndexImplParams<'a, T>, encode_min: &mut Bytes, encode_max: &mut Bytes, ) -> Result, DatabaseError> { - eq_to_res_scope(tuple, self, value, params, encode_min, encode_max) + eq_to_res_scope( + table_codec, + tuple, + self, + value, + params, + encode_min, + encode_max, + ) } fn bound_key( &self, + table_codec: &mut TableCodec, params: &IndexImplParams, value: &DataValue, is_upper: bool, out: &mut Bytes, ) -> Result<(), DatabaseError> { - let index = Index::new(params.index_meta.id, value, params.index_meta.ty); - unsafe { &*params.table_codec() }.with_index_key(params.table_name, &index, None, |key| { + let index = Index::new(params.index_meta().id, value, params.index_meta().ty); + table_codec.with_index(params.table_name.as_ref(), &index, None, |key, _| { encode_bound_key(out, key, is_upper); Ok(()) }) @@ -1651,6 +1729,7 @@ impl IndexImpl for CoveredIndexImpl { #[inline(always)] fn eq_to_res_scope<'a, T: Transaction + 'a>( + table_codec: &mut TableCodec, tuple: &mut Tuple, index_impl: &impl IndexImpl, value: &DataValue, @@ -1659,8 +1738,8 @@ fn eq_to_res_scope<'a, T: Transaction + 'a>( encode_max: &mut Bytes, ) -> Result, DatabaseError> { let _ = tuple; - index_impl.bound_key(params, value, false, encode_min)?; - index_impl.bound_key(params, value, true, encode_max)?; + index_impl.bound_key(table_codec, params, value, false, encode_min)?; + index_impl.bound_key(table_codec, params, value, true, encode_max)?; let iter = params.tx.range( Bound::Included(encode_min.as_slice()), @@ -1678,7 +1757,11 @@ pub struct TupleIter<'a, T: Transaction + 'a> { } impl<'a, T: Transaction + 'a> Iter for TupleIter<'a, T> { - fn next_tuple_into(&mut self, tuple: &mut Tuple) -> Result { + fn next_tuple_into( + &mut self, + _: &mut TableCodec, + tuple: &mut Tuple, + ) -> Result { while self.bounds.consume_offset() { if self.iter.try_next()?.is_none() { return Ok(false); @@ -1802,7 +1885,11 @@ pub enum IndexIterState<'a, T: Transaction + 'a> { } impl Iter for IndexIter<'_, T> { - fn next_tuple_into(&mut self, tuple: &mut Tuple) -> Result { + fn next_tuple_into( + &mut self, + table_codec: &mut TableCodec, + tuple: &mut Tuple, + ) -> Result { if self.bounds.limit_reached() { self.state = IndexIterState::Over; @@ -1819,8 +1906,9 @@ impl Iter for IndexIter<'_, T> { match binary { Range::Scope { min, max } => { let table_name = &self.params.table_name; - let index_meta = &self.params.index_meta; + let index_meta = self.params.index_meta(); let encode_min = encode_bound( + table_codec, min, false, &mut self.encode_min_buffer, @@ -1828,6 +1916,7 @@ impl Iter for IndexIter<'_, T> { &self.inner, )?; let encode_max = encode_bound( + table_codec, max, true, &mut self.encode_max_buffer, @@ -1835,7 +1924,6 @@ impl Iter for IndexIter<'_, T> { &self.inner, )?; - let table_codec = unsafe { &*self.params.table_codec() }; let tx = self.params.tx; let open_iter = move |bound_min: &[u8], bound_max: &[u8]| { tx.range( @@ -1856,6 +1944,7 @@ impl Iter for IndexIter<'_, T> { } Range::Eq(val) => { match self.inner.eq_to_res( + table_codec, tuple, val, &self.params, @@ -1884,8 +1973,13 @@ impl Iter for IndexIter<'_, T> { continue; } self.bounds.consume_limit(); - self.inner - .index_lookup_into(tuple, key, value, &self.params)?; + self.inner.index_lookup_into( + table_codec, + tuple, + key, + value, + &self.params, + )?; return Ok(true); } @@ -1902,7 +1996,11 @@ pub trait InnerIter { } pub trait Iter { - fn next_tuple_into(&mut self, tuple: &mut Tuple) -> Result; + fn next_tuple_into( + &mut self, + table_codec: &mut TableCodec, + tuple: &mut Tuple, + ) -> Result; } pub struct TableIter<'a, T: Transaction + 'a> { @@ -1923,6 +2021,7 @@ pub struct ViewIter<'a, T: Transaction + 'a> { iter: T::IterType<'a>, transaction: &'a T, table_cache: &'a TableCache, + table_arena: &'a TableArenaCell, scala_functions: &'a ScalaFunctions, table_functions: &'a TableFunctions, } @@ -1938,6 +2037,7 @@ impl ViewIter<'_, T> { (self.transaction, self.table_cache), self.scala_functions, self.table_functions, + self.table_arena.borrow_mut(), )?)) } } @@ -1945,7 +2045,8 @@ impl ViewIter<'_, T> { #[cfg(test)] pub(crate) fn next_tuple_for_test(iter: &mut I) -> Result, DatabaseError> { let mut tuple = Tuple::default(); - if iter.next_tuple_into(&mut tuple)? { + let mut table_codec = TableCodec::default(); + if iter.next_tuple_into(&mut table_codec, &mut tuple)? { Ok(Some(tuple)) } else { Ok(None) @@ -1962,40 +2063,22 @@ mod test { use crate::db::test::build_table; use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; + use crate::planner::{PlanArena, TableArenaCell}; use crate::storage::rocksdb::{RocksStorage, RocksTransaction}; use crate::storage::table_codec::TableCodec; use crate::storage::{ - IndexIter, InnerIter, StatisticsMetaCache, Storage, TableCache, Transaction, - TransactionIsolationLevel, + IndexIter, InnerIter, Storage, TableCache, Transaction, TransactionIsolationLevel, }; - use crate::types::index::{Index, IndexMeta, IndexType}; + use crate::types::index::{Index, IndexType}; use crate::types::tuple::Tuple; use crate::types::value::DataValue; - use crate::types::{ColumnId, LogicalType}; - use crate::utils::lru::SharedLruCache; + use crate::types::LogicalType; use std::collections::Bound; - use std::hash::RandomState; use std::sync::Arc; use tempfile::TempDir; - fn full_columns() -> Vec { - vec![ - ColumnRef::from(ColumnCatalog::new( - "c1".to_string(), - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( - "c2".to_string(), - false, - ColumnDesc::new(LogicalType::Boolean, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( - "c3".to_string(), - false, - ColumnDesc::new(LogicalType::Integer, None, false, None).unwrap(), - )), - ] + fn full_columns(table_cache: &TableCache) -> Vec { + table_cache.get("t1").unwrap().columns().copied().collect() } fn build_tuples() -> Vec { @@ -2032,12 +2115,16 @@ mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let mut table_cache = crate::storage::TableCache::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); - build_table(&table_cache, &mut transaction)?; + build_table(&mut table_cache, &mut transaction, &mut plan_arena)?; + let plan_arena = PlanArena::new(&table_arena); let fn_assert = |transaction: &mut RocksTransaction, - table_cache: &TableCache| + table_cache: &TableCache, + plan_arena: &PlanArena| -> Result<(), DatabaseError> { let table = transaction .table(table_cache, "t1".to_string().into())? @@ -2049,7 +2136,7 @@ mod test { assert_eq!(table.name.as_ref(), "t1"); assert_eq!(table.indexes.len(), 1); - let primary_key_index_meta = &table.indexes[0]; + let primary_key_index_meta = plan_arena.index(table.indexes[0]); assert_eq!(primary_key_index_meta.id, 0); assert_eq!(primary_key_index_meta.column_ids, vec![c1_column_id]); assert_eq!(primary_key_index_meta.table_name, "t1".to_string().into()); @@ -2061,7 +2148,7 @@ mod test { ); let mut column_iter = table.columns(); - let c1_column = column_iter.next().unwrap(); + let c1_column = plan_arena.column(*column_iter.next().unwrap()); assert!(!c1_column.nullable()); assert_eq!( c1_column.summary(), @@ -2079,7 +2166,7 @@ mod test { &ColumnDesc::new(LogicalType::Integer, Some(0), false, None)? ); - let c2_column = column_iter.next().unwrap(); + let c2_column = plan_arena.column(*column_iter.next().unwrap()); assert!(!c2_column.nullable()); assert_eq!( c2_column.summary(), @@ -2097,7 +2184,7 @@ mod test { &ColumnDesc::new(LogicalType::Boolean, None, false, None)? ); - let c3_column = column_iter.next().unwrap(); + let c3_column = plan_arena.column(*column_iter.next().unwrap()); assert!(!c3_column.nullable()); assert_eq!( c3_column.summary(), @@ -2117,28 +2204,37 @@ mod test { Ok(()) }; - fn_assert(&mut transaction, &table_cache)?; - fn_assert( - &mut transaction, - &Arc::new(SharedLruCache::new(4, 1, RandomState::new())?), - )?; + fn_assert(&mut transaction, &table_cache, &plan_arena)?; + let mut reloaded_table_cache = crate::storage::TableCache::default(); + let mut table_codec = TableCodec::default(); + let table_name = "t1".to_string().into(); + let table = transaction + .load_table(&mut table_codec, table_arena.borrow_mut(), table_name)? + .ok_or(DatabaseError::TableNotFound)?; + reloaded_table_cache.insert("t1".to_string().into(), table); + let plan_arena = PlanArena::new(&table_arena); + fn_assert(&mut transaction, &reloaded_table_cache, &plan_arena)?; Ok(()) } #[test] fn test_tuple_append_delete() -> Result<(), DatabaseError> { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let mut table_cache = crate::storage::TableCache::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); - build_table(&table_cache, &mut transaction)?; + build_table(&mut table_cache, &mut transaction, &mut plan_arena)?; + let plan_arena = PlanArena::new(&table_arena); let tuples = build_tuples(); for tuple in tuples.iter().cloned() { transaction.append_tuple( + &mut table_codec, "t1", tuple, &[ @@ -2151,10 +2247,12 @@ mod test { } { let mut tuple_iter = transaction.read( + &mut table_codec, + &plan_arena, &table_cache, "t1".to_string().into(), (None, None), - full_columns(), + full_columns(&table_cache), true, )?; @@ -2184,13 +2282,15 @@ mod test { assert!(iter.try_next()?.is_none()); } - transaction.remove_tuple("t1", &tuples[1].values[0])?; + transaction.remove_tuple(&mut table_codec, "t1", &tuples[1].values[0])?; { let mut tuple_iter = transaction.read( + &mut table_codec, + &plan_arena, &table_cache, "t1".to_string().into(), (None, None), - full_columns(), + full_columns(&table_cache), true, )?; @@ -2219,13 +2319,16 @@ mod test { #[test] fn test_add_index_meta() -> Result<(), DatabaseError> { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let mut table_cache = crate::storage::TableCache::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); - build_table(&table_cache, &mut transaction)?; + build_table(&mut table_cache, &mut transaction, &mut plan_arena)?; + let mut plan_arena = PlanArena::new(&table_arena); let (c2_column_id, c3_column_id) = { let t1_table = transaction .table(&table_cache, "t1".to_string().into())? @@ -2237,20 +2340,27 @@ mod test { ) }; - let _ = transaction.add_index_meta( - &table_cache, + let (table, _) = transaction.add_index_meta( + &mut table_codec, + &mut plan_arena, &"t1".to_string().into(), "i1".to_string(), vec![c3_column_id], IndexType::Normal, )?; - let _ = transaction.add_index_meta( - &table_cache, + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + let mut plan_arena = PlanArena::new(&table_arena); + let (table, _) = transaction.add_index_meta( + &mut table_codec, + &mut plan_arena, &"t1".to_string().into(), "i2".to_string(), vec![c3_column_id, c2_column_id], IndexType::Composite, )?; + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); let fn_assert = |transaction: &mut RocksTransaction, table_cache: &TableCache| @@ -2258,8 +2368,9 @@ mod test { let table = transaction .table(table_cache, "t1".to_string().into())? .unwrap(); + let plan_arena = PlanArena::new(&table_arena); - let i1_meta = table.indexes[1].clone(); + let i1_meta = plan_arena.index(table.indexes[1]); assert_eq!(i1_meta.id, 1); assert_eq!(i1_meta.column_ids, vec![c3_column_id]); assert_eq!(i1_meta.table_name, "t1".to_string().into()); @@ -2267,7 +2378,7 @@ mod test { assert_eq!(i1_meta.name, "i1".to_string()); assert_eq!(i1_meta.ty, IndexType::Normal); - let i2_meta = table.indexes[2].clone(); + let i2_meta = plan_arena.index(table.indexes[2]); assert_eq!(i2_meta.id, 2); assert_eq!(i2_meta.column_ids, vec![c3_column_id, c2_column_id]); assert_eq!(i2_meta.table_name, "t1".to_string().into()); @@ -2278,10 +2389,14 @@ mod test { Ok(()) }; fn_assert(&mut transaction, &table_cache)?; - fn_assert( - &mut transaction, - &Arc::new(SharedLruCache::new(4, 1, RandomState::new())?), - )?; + let mut reloaded_table_cache = crate::storage::TableCache::default(); + let table_name = "t1".to_string().into(); + let table = transaction + .load_table(&mut table_codec, table_arena.borrow_mut(), table_name)? + .ok_or(DatabaseError::TableNotFound)?; + reloaded_table_cache.insert("t1".to_string().into(), table); + let mut plan_arena = PlanArena::new(&table_arena); + fn_assert(&mut transaction, &reloaded_table_cache)?; { let mut iter = table_codec.with_index_meta_bound("t1", |min, max| { transaction.range(Bound::Included(min), Bound::Included(max)) @@ -2295,10 +2410,9 @@ mod test { dbg!(value); assert!(iter.try_next()?.is_none()); } - let meta_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); match transaction.drop_index( - &table_cache, - &meta_cache, + &mut table_codec, + &mut plan_arena, "t1".to_string().into(), "pk_index", false, @@ -2306,18 +2420,21 @@ mod test { Err(DatabaseError::InvalidIndex) => (), _ => unreachable!(), } - transaction.drop_index( - &table_cache, - &meta_cache, + if let Some((table, _)) = transaction.drop_index( + &mut table_codec, + &mut plan_arena, "t1".to_string().into(), "i1", false, - )?; + )? { + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + } { let table = transaction .table(&table_cache, "t1".to_string().into())? .unwrap(); - let i2_meta = table.indexes[1].clone(); + let i2_meta = plan_arena.index(table.indexes[1]); assert_eq!(i2_meta.id, 2); assert_eq!(i2_meta.column_ids, vec![c3_column_id, c2_column_id]); assert_eq!(i2_meta.table_name, "t1".to_string().into()); @@ -2343,23 +2460,26 @@ mod test { fn test_index_insert_delete() -> Result<(), DatabaseError> { fn build_index_iter<'a>( transaction: &'a RocksTransaction<'a>, - table_cache: &'a Arc, - index_column_id: ColumnId, + table_cache: &'a TableCache, + plan_arena: &'a PlanArena<'a>, ) -> Result>, DatabaseError> { + let table_name: crate::catalog::TableName = "t1".to_string().into(); + let index_meta = table_cache + .get(&table_name) + .and_then(|table| { + table + .indexes() + .copied() + .find(|index| plan_arena.index(*index).id == 1) + }) + .ok_or(DatabaseError::InvalidIndex)?; transaction.read_by_index( table_cache, + plan_arena, "t1".to_string().into(), (None, None), - full_columns(), - Arc::new(IndexMeta { - id: 1, - column_ids: vec![index_column_id], - table_name: "t1".to_string().into(), - pk_ty: LogicalType::Integer, - value_ty: LogicalType::Integer, - name: "i1".to_string(), - ty: IndexType::Normal, - }), + full_columns(table_cache), + index_meta, vec![Range::Scope { min: Bound::Unbounded, max: Bound::Unbounded, @@ -2370,25 +2490,32 @@ mod test { ) } - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let mut table_cache = crate::storage::TableCache::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); - build_table(&table_cache, &mut transaction)?; + build_table(&mut table_cache, &mut transaction, &mut plan_arena)?; + let mut plan_arena = PlanArena::new(&table_arena); let t1_table = transaction .table(&table_cache, "t1".to_string().into())? .unwrap(); let c3_column_id = *t1_table.get_column_id_by_name("c3").unwrap(); - let _ = transaction.add_index_meta( - &table_cache, + let (table, _) = transaction.add_index_meta( + &mut table_codec, + &mut plan_arena, &"t1".to_string().into(), "i1".to_string(), vec![c3_column_id], IndexType::Normal, )?; + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + let plan_arena = PlanArena::new(&table_arena); let tuples = build_tuples(); let indexes = [ @@ -2406,10 +2533,11 @@ mod test { ), ]; for (tuple_id, index) in indexes.iter().cloned() { - transaction.add_index("t1", index, &tuple_id)?; + transaction.add_index(&mut table_codec, "t1", index, &tuple_id)?; } for tuple in tuples.iter().cloned() { transaction.append_tuple( + &mut table_codec, "t1", tuple, &[ @@ -2421,7 +2549,7 @@ mod test { )?; } { - let mut index_iter = build_index_iter(&transaction, &table_cache, c3_column_id)?; + let mut index_iter = build_index_iter(&transaction, &table_cache, &plan_arena)?; assert_eq!( super::next_tuple_for_test(&mut index_iter)?.unwrap(), @@ -2448,9 +2576,9 @@ mod test { dbg!(value); assert!(iter.try_next()?.is_none()); } - transaction.del_index("t1", &indexes[0].1, &indexes[0].0)?; + transaction.del_index(&mut table_codec, "t1", &indexes[0].1, &indexes[0].0)?; - let mut index_iter = build_index_iter(&transaction, &table_cache, c3_column_id)?; + let mut index_iter = build_index_iter(&transaction, &table_cache, &plan_arena)?; assert_eq!( super::next_tuple_for_test(&mut index_iter)?.unwrap(), @@ -2476,10 +2604,10 @@ mod test { #[test] fn test_reader_transaction_can_mix_index_and_heap_views() -> Result<(), DatabaseError> { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); + let mut table_cache = crate::storage::TableCache::default(); let serializers = [ LogicalType::Integer.serializable(), LogicalType::Boolean.serializable(), @@ -2502,33 +2630,45 @@ mod test { DataValue::Int32(1), ], ); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); let index_id = { let mut setup_tx = storage.transaction()?; - build_table(&table_cache, &mut setup_tx)?; + build_table(&mut table_cache, &mut setup_tx, &mut plan_arena)?; + let mut plan_arena = PlanArena::new(&table_arena); let table = setup_tx .table(&table_cache, "t1".to_string().into())? .unwrap(); let c3_column_id = *table.get_column_id_by_name("c3").unwrap(); - let index_id = setup_tx.add_index_meta( - &table_cache, + let (table, index_id) = setup_tx.add_index_meta( + &mut table_codec, + &mut plan_arena, &"t1".to_string().into(), "i1".to_string(), vec![c3_column_id], IndexType::Normal, )?; + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); setup_tx.add_index( + &mut table_codec, "t1", Index::new(index_id, &initial_tuple.values[2], IndexType::Normal), initial_tuple.pk.as_ref().unwrap(), )?; - setup_tx.append_tuple("t1", initial_tuple.clone(), &serializers, false)?; + setup_tx.append_tuple( + &mut table_codec, + "t1", + initial_tuple.clone(), + &serializers, + false, + )?; setup_tx.commit()?; index_id }; - let reader_tx = storage.transaction_with_isolation(TransactionIsolationLevel::ReadCommitted)?; let tuple_id = { @@ -2540,7 +2680,7 @@ mod test { TableCodec::decode_index(value)? }; - let before_update = table_codec.with_tuple_key("t1", &tuple_id, |key| { + let before_update = table_codec.with_tuple("t1", &tuple_id, None, |key, _| { let bytes = reader_tx.get_borrowed(key)?.expect("tuple should exist"); let mut tuple = Tuple::default(); @@ -2557,19 +2697,27 @@ mod test { let mut writer_tx = storage.transaction()?; writer_tx.del_index( + &mut table_codec, "t1", &Index::new(index_id, &initial_tuple.values[2], IndexType::Normal), initial_tuple.pk.as_ref().unwrap(), )?; writer_tx.add_index( + &mut table_codec, "t1", Index::new(index_id, &updated_tuple.values[2], IndexType::Normal), updated_tuple.pk.as_ref().unwrap(), )?; - writer_tx.append_tuple("t1", updated_tuple.clone(), &serializers, true)?; + writer_tx.append_tuple( + &mut table_codec, + "t1", + updated_tuple.clone(), + &serializers, + true, + )?; writer_tx.commit()?; - let after_update = table_codec.with_tuple_key("t1", &tuple_id, |key| { + let after_update = table_codec.with_tuple("t1", &tuple_id, None, |key, _| { let bytes = reader_tx.get_borrowed(key)?.expect("tuple should exist"); let mut tuple = Tuple::default(); @@ -2595,10 +2743,13 @@ mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let meta_cache = StatisticsMetaCache::new(4, 1, RandomState::new())?; + let mut table_cache = crate::storage::TableCache::default(); + let mut table_codec = TableCodec::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); - build_table(&table_cache, &mut transaction)?; + build_table(&mut table_cache, &mut transaction, &mut plan_arena)?; + let mut plan_arena = PlanArena::new(&table_arena); let table_name: TableName = "t1".to_string().into(); let new_column = ColumnCatalog::new( @@ -2606,15 +2757,37 @@ mod test { true, ColumnDesc::new(LogicalType::Integer, None, false, None)?, ); - let new_column_id = - transaction.add_column(&table_cache, &table_name, &new_column, false)?; + let (table, new_column_id) = transaction.add_column( + &mut table_codec, + &mut plan_arena, + &table_name, + &new_column, + false, + )?; + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + let mut plan_arena = PlanArena::new(&table_arena); { assert!(transaction - .add_column(&table_cache, &table_name, &new_column, false,) + .add_column( + &mut table_codec, + &mut plan_arena, + &table_name, + &new_column, + false, + ) .is_err()); assert_eq!( new_column_id, - transaction.add_column(&table_cache, &table_name, &new_column, true,)? + transaction + .add_column( + &mut table_codec, + &mut plan_arena, + &table_name, + &new_column, + true, + )? + .1 ); } { @@ -2633,12 +2806,13 @@ mod test { table_name: table_name.clone(), is_temp: false, }; - assert_eq!( - table.get_column_by_name("c4"), - Some(&ColumnRef::from(new_column)) - ); + let column = table.get_column_by_name("c4").unwrap(); + assert_eq!(table_arena.borrow().column(column), &new_column); } - transaction.drop_column(&table_cache, &meta_cache, &table_name, "c4")?; + let table = + transaction.drop_column(&mut table_codec, &mut plan_arena, &table_name, "c4")?; + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); { let table = transaction .table(&table_cache, table_name.clone())? @@ -2657,21 +2831,29 @@ mod test { let table_functions = Default::default(); let view_name: TableName = "v1".to_string().into(); + let mut plan_arena = PlanArena::new(&table_state.table_arena); + let mut plan = table_state.plan_with_arena( + "select c1, c3 from t1 inner join t2 on c1 = c3 and c1 > 1", + &mut plan_arena, + )?; + let schema = plan.output_schema(&mut plan_arena).clone(); let view = View { name: view_name.clone(), - plan: Box::new( - table_state.plan("select c1, c3 from t1 inner join t2 on c1 = c3 and c1 > 1")?, - ), + plan: Box::new(plan), + schema, }; let mut transaction = table_state.storage.transaction()?; - transaction.create_view(&table_state.view_cache, view.clone(), true)?; + let mut view_cache = table_state.view_cache.clone(); + let mut table_codec = TableCodec::default(); + let view = transaction.create_view(&mut table_codec, &plan_arena, view.clone(), true)?; + view_cache.insert(view.name.clone(), view.clone()); assert_eq!( &view, transaction .view( &table_state.table_cache, - &table_state.view_cache, + &view_cache, &scala_functions, &table_functions, view_name.clone(), @@ -2682,8 +2864,8 @@ mod test { &view, transaction .view( - &Arc::new(SharedLruCache::new(4, 1, RandomState::new())?), - &table_state.view_cache, + &crate::storage::TableCache::default(), + &view_cache, &scala_functions, &table_functions, view_name.clone(), @@ -2691,12 +2873,16 @@ mod test { .unwrap() ); - transaction.drop_view(&table_state.view_cache, view_name.clone(), false)?; - transaction.drop_view(&table_state.view_cache, view_name.clone(), true)?; + if transaction.drop_view(&mut table_codec, view_name.clone(), false)? { + view_cache.remove(&view_name); + } + if transaction.drop_view(&mut table_codec, view_name.clone(), true)? { + view_cache.remove(&view_name); + } assert!(transaction .view( &table_state.table_cache, - &table_state.view_cache, + &view_cache, &scala_functions, &table_functions, view_name, diff --git a/src/storage/rocksdb.rs b/src/storage/rocksdb.rs index bcf24dde..f31b9f68 100644 --- a/src/storage/rocksdb.rs +++ b/src/storage/rocksdb.rs @@ -13,7 +13,6 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::storage::table_codec::TableCodec; use crate::storage::{ CheckpointableStorage, InnerIter, Storage, Transaction, TransactionIsolationLevel, }; @@ -447,7 +446,6 @@ impl Storage for OptimisticRocksStorage { isolation, current_snapshot: matches!(isolation, TransactionIsolationLevel::RepeatableRead) .then(|| self.inner.snapshot()), - table_codec: Default::default(), }) } @@ -495,7 +493,6 @@ impl Storage for RocksStorage { isolation, current_snapshot: matches!(isolation, TransactionIsolationLevel::RepeatableRead) .then(|| self.inner.snapshot()), - table_codec: Default::default(), }) } @@ -551,7 +548,7 @@ impl CheckpointableStorage for RocksStorage { return Err(err); } - return Ok(()); + Ok(()) } #[cfg(not(feature = "unsafe_txdb_checkpoint"))] @@ -566,7 +563,6 @@ pub struct OptimisticRocksTransaction<'db> { tx: rocksdb::Transaction<'db, OptimisticTransactionDB>, isolation: TransactionIsolationLevel, current_snapshot: Option>, - table_codec: TableCodec, } pub struct RocksTransaction<'db> { @@ -574,7 +570,6 @@ pub struct RocksTransaction<'db> { tx: rocksdb::Transaction<'db, TransactionDB>, isolation: TransactionIsolationLevel, current_snapshot: Option>>, - table_codec: TableCodec, } fn build_read_options( @@ -600,12 +595,6 @@ macro_rules! impl_transaction { = $iter<'storage, 'iter> where Self: 'iter; - - #[inline] - fn table_codec(&self) -> *const TableCodec { - &self.table_codec - } - fn begin_statement_scope(&mut self) -> Result<(), DatabaseError> { if self.isolation == TransactionIsolationLevel::ReadCommitted { self.current_snapshot = Some(self.db.snapshot()); @@ -770,24 +759,23 @@ fn next<'a, D: rocksdb::DBAccess>( #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef, TableName}; - use crate::db::DataBaseBuilder; + use crate::catalog::{ColumnCatalog, ColumnDesc, TableName}; + use crate::db::{CatalogKind, DataBaseBuilder}; use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; + use crate::planner::{PlanArena, TableArenaCell}; use crate::storage::rocksdb::RocksStorage; + use crate::storage::table_codec::TableCodec; use crate::storage::{ IndexImplEnum, IndexImplParams, IndexIter, IndexIterState, InnerIter, IterBounds, PrimaryKeyIndexImpl, Storage, Transaction, }; - use crate::types::index::{IndexMeta, IndexType}; + use crate::types::index::IndexType; use crate::types::tuple::Tuple; use crate::types::value::DataValue; use crate::types::LogicalType; - use crate::utils::lru::SharedLruCache; use itertools::Itertools; use std::collections::Bound; - use std::hash::RandomState; - use std::sync::Arc; use tempfile::TempDir; #[test] @@ -802,12 +790,11 @@ mod test { #[test] fn test_collect_rocksdb_metrics_snapshot() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()) + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()) .storage_statistics(true) .build_rocksdb()?; - kite_sql - .run("create table t_metrics (a int primary key, b int)")? - .done()?; + kite_sql.ddl("create table t_metrics (a int primary key, b int)")?; + kite_sql.load(CatalogKind::Table("t_metrics".to_string().into()))?; kite_sql .run("insert into t_metrics values (1, 10), (2, 20), (3, 30)")? .done()?; @@ -854,36 +841,41 @@ mod test { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); let storage = RocksStorage::new(temp_dir.path())?; let mut transaction = storage.transaction()?; - let table_cache = Arc::new(SharedLruCache::new(4, 1, RandomState::new())?); - let columns = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new( + let mut table_cache = crate::storage::TableCache::default(); + let mut table_codec = TableCodec::default(); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); + let source_columns = vec![ + ColumnCatalog::new( "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::Boolean, None, false, None).unwrap(), - )), - ]); - - let source_columns = columns - .iter() - .map(|col_ref| ColumnCatalog::clone(col_ref)) - .collect_vec(); - let _ = transaction.create_table( - &table_cache, + ), + ]; + if let Some(table) = transaction.create_table( + &mut table_codec, + &mut plan_arena, "test".to_string().into(), source_columns, false, - )?; + )? { + let table = table.transplant_to_table_arena(&plan_arena)?; + table_cache.insert(table.name().clone(), table); + } + let plan_arena = PlanArena::new(&table_arena); + let table_name: TableName = "test".to_string().into(); let table_catalog = transaction.table(&table_cache, "test".to_string().into())?; assert!(table_catalog.is_some()); assert!(table_catalog.unwrap().get_column_id_by_name("c1").is_some()); transaction.append_tuple( + &mut table_codec, "test", Tuple::new( Some(DataValue::Int32(1)), @@ -896,6 +888,7 @@ mod test { false, )?; transaction.append_tuple( + &mut table_codec, "test", Tuple::new( Some(DataValue::Int32(2)), @@ -908,9 +901,18 @@ mod test { false, )?; - let read_columns = vec![columns[0].clone()]; + let read_column = table_cache + .get(&table_name) + .unwrap() + .columns() + .next() + .copied() + .unwrap(); + let read_columns = vec![read_column]; let mut iter = transaction.read( + &mut table_codec, + &plan_arena, &table_cache, "test".to_string().into(), (Some(1), Some(1)), @@ -930,11 +932,10 @@ mod test { #[test] fn test_index_iter_pk() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key)")? - .done()?; + kite_sql.ddl("create table t1 (a int primary key)")?; + kite_sql.load(CatalogKind::Table("t1".to_string().into()))?; kite_sql .run("insert into t1 (a) values (0), (1), (2), (3), (4)")? .done()?; @@ -945,7 +946,17 @@ mod test { .table(kite_sql.state.table_cache(), table_name.clone())? .unwrap() .clone(); - let a_column_id = table.get_column_id_by_name("a").unwrap(); + let plan_arena = PlanArena::new(kite_sql.state.table_arena()); + let index_meta = table + .indexes() + .copied() + .find(|index| { + matches!( + plan_arena.index(*index).ty, + IndexType::PrimaryKey { is_multiple: false } + ) + }) + .ok_or(DatabaseError::InvalidIndex)?; let tuple_ids = vec![ DataValue::Int32(0), DataValue::Int32(2), @@ -954,21 +965,14 @@ mod test { ]; let deserializers = table .columns() - .map(|column| column.datatype().serializable()) + .map(|column| plan_arena.column(*column).datatype().serializable()) .collect_vec(); let mut iter: IndexIter<'_, _> = IndexIter { bounds: IterBounds::new(0, None), params: IndexImplParams { - index_meta: Arc::new(IndexMeta { - id: 0, - column_ids: vec![*a_column_id], - table_name, - pk_ty: LogicalType::Integer, - value_ty: LogicalType::Integer, - name: "pk_a".to_string(), - ty: IndexType::PrimaryKey { is_multiple: false }, - }), - table_name: &table.name, + index_meta, + meta_arena: plan_arena.table_arena_cell().borrow(), + table_name: table.name.clone(), deserializers, total_len: table.columns_len(), tx: &transaction, @@ -1002,10 +1006,9 @@ mod test { #[test] fn test_read_by_index() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key, b int unique)")? - .done()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + kite_sql.ddl("create table t1 (a int primary key, b int unique)")?; + kite_sql.load(CatalogKind::Table("t1".to_string().into()))?; kite_sql .run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2), (3, 4)")? .done()?; @@ -1015,14 +1018,16 @@ mod test { .table(kite_sql.state.table_cache(), "t1".to_string().into())? .unwrap() .clone(); + let plan_arena = PlanArena::new(kite_sql.state.table_arena()); { let mut iter = transaction .read_by_index( kite_sql.state.table_cache(), + &plan_arena, "t1".to_string().into(), (Some(0), Some(1)), table.columns().cloned().collect(), - table.indexes[0].clone(), + table.indexes[0], vec![Range::Scope { min: Bound::Excluded(DataValue::Int32(0)), max: Bound::Unbounded, @@ -1046,10 +1051,11 @@ mod test { let mut iter = transaction .read_by_index( kite_sql.state.table_cache(), + &plan_arena, "t1".to_string().into(), (Some(0), Some(1)), columns, - table.indexes[0].clone(), + table.indexes[0], vec![Range::Scope { min: Bound::Excluded(DataValue::Int32(3)), max: Bound::Unbounded, @@ -1072,66 +1078,65 @@ mod test { #[test] fn test_read_by_index_cover() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("unable to create temporary working directory"); - let kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; - kite_sql - .run("create table t1 (a int primary key, b int unique)")? - .done()?; + let mut kite_sql = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + kite_sql.ddl("create table t1 (a int primary key, b int unique)")?; + kite_sql.load(CatalogKind::Table("t1".to_string().into()))?; kite_sql .run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2), (3, 4)")? .done()?; - kite_sql.run("create index idx_b_a on t1(b, a)")?.done()?; + kite_sql.ddl("create index idx_b_a on t1(b, a)")?; + kite_sql.load(CatalogKind::Table("t1".to_string().into()))?; let mut transaction = kite_sql.storage.transaction().unwrap(); let table = transaction .table(kite_sql.state.table_cache(), "t1".to_string().into())? .unwrap() .clone(); + let plan_arena = PlanArena::new(kite_sql.state.table_arena()); let columns_vec: Vec<_> = table.columns().cloned().collect(); - let a_cover_column = columns_vec + let a_cover_column = *columns_vec .iter() - .find(|column| column.name() == "a") - .unwrap() - .clone(); - let b_cover_column = columns_vec + .find(|column| plan_arena.column(**column).name() == "a") + .unwrap(); + let b_cover_column = *columns_vec .iter() - .find(|column| column.name() == "b") - .unwrap() - .clone(); - let unique_index = table + .find(|column| plan_arena.column(**column).name() == "b") + .unwrap(); + let unique_index = *table .indexes .iter() - .find(|index| matches!(index.ty, IndexType::Unique)) - .unwrap() - .clone(); + .find(|index| matches!(plan_arena.index(**index).ty, IndexType::Unique)) + .unwrap(); let b_column = table .columns() .cloned() - .find(|column| column.name() == "b") + .find(|column| plan_arena.column(*column).name() == "b") .unwrap(); - let columns = vec![b_column.clone()]; - let covered_deserializers = vec![b_column.datatype().serializable()]; + let columns = vec![b_column]; + let covered_deserializers = vec![plan_arena.column(b_column).datatype().serializable()]; // ensure cover mapping can reorder index values to match scan order - let composite_index = table + let composite_index = *table .indexes .iter() - .find(|index| index.name == "idx_b_a") - .unwrap() - .clone(); - let reordered_columns = vec![a_cover_column.clone(), b_cover_column.clone()]; + .find(|index| plan_arena.index(**index).name == "idx_b_a") + .unwrap(); + let reordered_columns = vec![a_cover_column, b_cover_column]; let reordered_deserializers = vec![ - a_cover_column.datatype().serializable(), - b_cover_column.datatype().serializable(), + plan_arena.column(a_cover_column).datatype().serializable(), + plan_arena.column(b_cover_column).datatype().serializable(), ]; - let a_id = a_cover_column.id().unwrap(); - let b_id = b_cover_column.id().unwrap(); + let a_id = plan_arena.column(a_cover_column).id().unwrap(); + let b_id = plan_arena.column(b_cover_column).id().unwrap(); let cover_mapping = vec![ - composite_index + plan_arena + .index(composite_index) .column_ids .iter() .position(|id| id == &a_id) .unwrap(), - composite_index + plan_arena + .index(composite_index) .column_ids .iter() .position(|id| id == &b_id) @@ -1140,6 +1145,7 @@ mod test { let mut iter = transaction.read_by_index( kite_sql.state.table_cache(), + &plan_arena, "t1".to_string().into(), (None, None), reordered_columns, @@ -1161,10 +1167,12 @@ mod test { let target_pk = DataValue::Int32(3); let covered_value = DataValue::Int32(4); - transaction.remove_tuple("t1", &target_pk)?; + let mut table_codec = TableCodec::default(); + transaction.remove_tuple(&mut table_codec, "t1", &target_pk)?; let mut iter = transaction.read_by_index( kite_sql.state.table_cache(), + &plan_arena, "t1".to_string().into(), (Some(0), Some(1)), columns, @@ -1185,16 +1193,16 @@ mod test { assert_eq!(tuples[0].values, vec![covered_value]); // primary key index should ignore covered-deserializer hint and still return rows - let pk_index = table + let pk_index = *table .indexes .iter() - .find(|index| index.name == "pk_index") - .unwrap() - .clone(); - let pk_columns = vec![a_cover_column.clone()]; - let pk_deserializers = vec![a_cover_column.datatype().serializable()]; + .find(|index| plan_arena.index(**index).name == "pk_index") + .unwrap(); + let pk_columns = vec![a_cover_column]; + let pk_deserializers = vec![plan_arena.column(a_cover_column).datatype().serializable()]; let mut iter = transaction.read_by_index( kite_sql.state.table_cache(), + &plan_arena, "t1".to_string().into(), (None, None), pk_columns, diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 1a78fa4a..91a2de7a 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -13,12 +13,13 @@ // limitations under the License. use crate::catalog::view::View; -use crate::catalog::{ColumnRef, ColumnRelation, TableMeta}; +use crate::catalog::{ColumnCatalog, ColumnRelation, TableMeta}; use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::optimizer::core::cm_sketch::{CountMinSketchMeta, CountMinSketchPage}; use crate::optimizer::core::histogram::Bucket; use crate::optimizer::core::statistics_meta::StatisticsMetaRoot; +use crate::planner::{MetaArena, TableArena}; use crate::serdes::{ReferenceDecodeContext, ReferenceSerialization, ReferenceTables}; use crate::storage::{TableCache, Transaction}; use crate::types::index::{Index, IndexId, IndexMeta, IndexType, INDEX_ID_LEN}; @@ -26,9 +27,8 @@ use crate::types::serialize::TupleValueSerializableImpl; use crate::types::tuple::{Tuple, TupleId}; use crate::types::value::{DataValue, TupleMappingRef}; use crate::types::LogicalType; -use bumpalo::Bump; use siphasher::sip::SipHasher; -use std::cell::RefCell; +use std::borrow::Borrow; use std::hash::{Hash, Hasher}; use std::io::{Cursor, Read, Seek, SeekFrom}; use std::sync::LazyLock; @@ -50,19 +50,28 @@ static HASH_BYTES: LazyLock> = LazyLock::new(|| b"Hash".to_vec()); pub type Bytes = Vec; pub type BumpBytes<'bump> = bumpalo::collections::Vec<'bump, u8>; +type TupleValueWriter<'a> = &'a mut dyn FnMut(&Tuple, &mut Bytes) -> Result<(), DatabaseError>; #[derive(Default)] pub struct TableCodec { - arena: Bump, - key_buffer: RefCell, + buffers: [Bytes; 2], + reference_tables: ReferenceTables, } -#[derive(Default)] -struct KeyBuffer { - lower: Bytes, - upper: Bytes, - cached_table_name: String, - cached_table_hash: [u8; TABLE_NAME_HASH_LEN], +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +enum CodecSlot { + S0, + S1, +} + +impl CodecSlot { + #[inline] + fn index(self) -> usize { + match self { + CodecSlot::S0 => 0, + CodecSlot::S1 => 1, + } + } } #[derive(Copy, Clone)] @@ -110,8 +119,16 @@ impl StatisticsCodecType { impl TableCodec { #[inline] - pub fn arena(&self) -> &Bump { - &self.arena + fn clear_buffers(&mut self) { + self.buffers[CodecSlot::S0.index()].clear(); + self.buffers[CodecSlot::S1.index()].clear(); + self.reference_tables.clear(); + } + + #[inline] + fn slots_mut(&mut self) -> (&mut Bytes, &mut Bytes, &mut ReferenceTables) { + let [s0, s1] = &mut self.buffers; + (s0, s1, &mut self.reference_tables) } fn hash_bytes(table_name: &str) -> [u8; 8] { @@ -166,7 +183,6 @@ impl TableCodec { #[inline] fn write_key_prefix(out: &mut Bytes, ty: CodecType, table_hash: [u8; TABLE_NAME_HASH_LEN]) { - out.clear(); match ty { CodecType::Column => { out.extend_from_slice(&table_hash); @@ -208,69 +224,70 @@ impl TableCodec { #[inline] fn write_global_bound_prefix(out: &mut Bytes, prefix: &[u8], bound: u8) { - out.clear(); out.extend_from_slice(prefix); out.push(bound); } - fn with_table_hash( - &self, + fn with_table_hash_buffers( + &mut self, table_name: &str, - f: impl FnOnce(&mut KeyBuffer, [u8; TABLE_NAME_HASH_LEN]) -> Result, + f: impl FnOnce( + &mut Bytes, + [u8; TABLE_NAME_HASH_LEN], + &mut Bytes, + &mut ReferenceTables, + ) -> Result, ) -> Result { - let mut key_buffer = self.key_buffer.borrow_mut(); - let table_hash = if key_buffer.cached_table_name != table_name { - key_buffer.cached_table_name.clear(); - key_buffer.cached_table_name.push_str(table_name); - key_buffer.cached_table_hash = Self::hash_bytes(table_name); - key_buffer.cached_table_hash - } else { - key_buffer.cached_table_hash - }; - - f(&mut key_buffer, table_hash) + let table_hash = Self::hash_bytes(table_name); + let (s0, s1, reference_tables) = self.slots_mut(); + f(s0, table_hash, s1, reference_tables) } /// Key: `{TableName}{TUPLE_TAG}{BOUND_MIN_TAG}{RowID}`. - pub fn with_tuple_key( - &self, + pub fn with_tuple( + &mut self, table_name: &str, tuple_id: &TupleId, - f: impl FnOnce(&[u8]) -> Result, + tuple_value: Option<(&Tuple, TupleValueWriter<'_>)>, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { Self::check_primary_key(tuple_id, 0)?; - self.with_tuple_key_unchecked(table_name, tuple_id, f) + self.with_tuple_unchecked(table_name, tuple_id, tuple_value, f) } #[inline] - pub(crate) fn with_tuple_key_unchecked( - &self, + pub(crate) fn with_tuple_unchecked( + &mut self, table_name: &str, tuple_id: &TupleId, - f: impl FnOnce(&[u8]) -> Result, + tuple_value: Option<(&Tuple, TupleValueWriter<'_>)>, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, _| { Self::write_key_prefix(lower, CodecType::Tuple, table_hash); lower.push(BOUND_MIN_TAG); tuple_id.memcomparable_encode(lower)?; - f(lower.as_slice()) + if let Some((tuple, serializers)) = tuple_value { + serializers(tuple, value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } /// Range bounds covering all tuple keys for a table. pub fn with_tuple_bound( - &self, + &mut self, table_name: &str, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, upper, _| { Self::write_key_prefix(lower, CodecType::Tuple, table_hash); lower.push(BOUND_MIN_TAG); - let upper = &mut key_buffer.upper; Self::write_key_prefix(upper, CodecType::Tuple, table_hash); upper.push(BOUND_MAX_TAG); @@ -279,34 +296,38 @@ impl TableCodec { } /// Key: `{TableName}{INDEX_META_TAG}{BOUND_MIN_TAG}{IndexID}`. - pub fn with_index_meta_key( - &self, + pub fn with_index_meta( + &mut self, table_name: &str, index_id: IndexId, - f: impl FnOnce(&[u8]) -> Result, + index_meta: Option<&IndexMeta>, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, refs| { Self::write_key_prefix(lower, CodecType::IndexMeta, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&index_id.to_le_bytes()); - f(lower.as_slice()) + if let Some(index_meta) = index_meta { + Self::encode_index_meta_value_into(index_meta, refs, value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } /// Range bounds covering all index metadata for a table. pub fn with_index_meta_bound( - &self, + &mut self, table_name: &str, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, upper, _| { Self::write_key_prefix(lower, CodecType::IndexMeta, table_hash); lower.push(BOUND_MIN_TAG); - let upper = &mut key_buffer.upper; Self::write_key_prefix(upper, CodecType::IndexMeta, table_hash); upper.push(BOUND_MAX_TAG); @@ -316,19 +337,18 @@ impl TableCodec { /// Range bounds covering a single secondary index. pub fn with_index_bound( - &self, + &mut self, table_name: &str, index_id: IndexId, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, upper, _| { Self::write_key_prefix(lower, CodecType::Index, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&index_id.to_le_bytes()); lower.push(BOUND_MIN_TAG); - let upper = &mut key_buffer.upper; Self::write_key_prefix(upper, CodecType::Index, table_hash); upper.push(BOUND_MIN_TAG); upper.extend_from_slice(&index_id.to_le_bytes()); @@ -340,16 +360,15 @@ impl TableCodec { /// Range bounds covering all secondary indexes for a table. pub fn with_all_index_bound( - &self, + &mut self, table_name: &str, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, upper, _| { Self::write_key_prefix(lower, CodecType::Index, table_hash); lower.push(BOUND_MIN_TAG); - let upper = &mut key_buffer.upper; Self::write_key_prefix(upper, CodecType::Index, table_hash); upper.push(BOUND_MAX_TAG); @@ -362,15 +381,15 @@ impl TableCodec { /// /// Unique index key: /// `{TableName}{INDEX_TAG}{BOUND_MIN_TAG}{IndexID}{BOUND_MIN_TAG}{DataValue}` - pub fn with_index_key( - &self, + pub fn with_index( + &mut self, table_name: &str, index: &Index, tuple_id: Option<&TupleId>, - f: impl FnOnce(&[u8]) -> Result, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, _| { Self::write_key_prefix(lower, CodecType::Index, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&index.id.to_le_bytes()); @@ -383,15 +402,20 @@ impl TableCodec { } } - f(lower.as_slice()) + if let Some(tuple_id) = tuple_id { + tuple_id.encode_reference_value(&mut *value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } /// Key: `{TableName}{COLUMN_TAG}{BOUND_MIN_TAG}{ColumnId}`. - pub fn with_column_key( - &self, - col: &ColumnRef, - f: impl FnOnce(&[u8]) -> Result, + pub fn with_column( + &mut self, + col: &ColumnCatalog, + encode_value: bool, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { if let ColumnRelation::Table { column_id, @@ -399,13 +423,18 @@ impl TableCodec { is_temp: false, } = &col.summary().relation { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name.as_ref(), |lower, table_hash, value, refs| { Self::write_key_prefix(lower, CodecType::Column, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&column_id.to_bytes()); - f(lower.as_slice()) + if encode_value { + let _ = refs.push_or_replace(table_name); + Self::encode_column_value_into(col, refs, value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } else { Err(DatabaseError::invalid_column( @@ -416,16 +445,15 @@ impl TableCodec { /// Range bounds covering all column metadata for a table. pub fn with_columns_bound( - &self, + &mut self, table_name: &str, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, upper, _| { Self::write_key_prefix(lower, CodecType::Column, table_hash); lower.push(BOUND_MIN_TAG); - let upper = &mut key_buffer.upper; Self::write_key_prefix(upper, CodecType::Column, table_hash); upper.push(BOUND_MAX_TAG); @@ -435,16 +463,15 @@ impl TableCodec { /// Range bounds spanning a table's `Column` and `IndexMeta` metadata. pub fn with_table_bound( - &self, + &mut self, table_name: &str, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, upper, _| { Self::write_key_prefix(lower, CodecType::Column, table_hash); lower.push(BOUND_MIN_TAG); - let upper = &mut key_buffer.upper; Self::write_key_prefix(upper, CodecType::IndexMeta, table_hash); upper.push(BOUND_MAX_TAG); @@ -454,16 +481,15 @@ impl TableCodec { /// Range bounds covering all statistics keys for a table. pub fn with_statistics_bound( - &self, + &mut self, table_name: &str, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, upper, _| { Self::write_key_prefix(lower, CodecType::Statistics, table_hash); lower.push(BOUND_MIN_TAG); - let upper = &mut key_buffer.upper; Self::write_key_prefix(upper, CodecType::Statistics, table_hash); upper.push(BOUND_MAX_TAG); @@ -473,18 +499,17 @@ impl TableCodec { /// Range bounds covering all statistics keys for one index. pub fn with_statistics_index_bound( - &self, + &mut self, table_name: &str, index_id: IndexId, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, upper, _| { Self::write_key_prefix(lower, CodecType::Statistics, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&index_id.to_le_bytes()); - let upper = &mut key_buffer.upper; upper.clear(); upper.extend_from_slice(lower); upper.push(BOUND_MAX_TAG); @@ -494,51 +519,62 @@ impl TableCodec { } /// Key: `{TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{ROOT_TAG}`. - pub fn with_statistics_meta_key( - &self, + pub fn with_statistics_meta( + &mut self, table_name: &str, index_id: IndexId, - f: impl FnOnce(&[u8]) -> Result, + statistics_meta: Option<&StatisticsMetaRoot>, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, refs| { Self::write_key_prefix(lower, CodecType::Statistics, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&index_id.to_le_bytes()); lower.push(StatisticsCodecType::Root.tag()); - f(lower.as_slice()) + if let Some(statistics_meta) = statistics_meta { + Self::encode_statistics_meta_value_into(statistics_meta, refs, value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } /// Key: `{TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{SKETCH_META_TAG}`. - pub fn with_statistics_sketch_meta_key( - &self, + pub fn with_statistics_sketch_meta( + &mut self, table_name: &str, index_id: IndexId, - f: impl FnOnce(&[u8]) -> Result, + sketch_meta: Option<&CountMinSketchMeta>, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, refs| { Self::write_key_prefix(lower, CodecType::Statistics, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&index_id.to_le_bytes()); lower.push(StatisticsCodecType::SketchMeta.tag()); - f(lower.as_slice()) + if let Some(sketch_meta) = sketch_meta { + Self::encode_statistics_sketch_meta_value_into(sketch_meta, refs, value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } /// Key: `{TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{SKETCH_PAGE_TAG}{BOUND_MIN_TAG}{ROW_ID}{BOUND_MIN_TAG}{PAGE_ID}`. - pub fn with_statistics_sketch_page_key( - &self, + pub fn with_statistics_sketch_page( + &mut self, table_name: &str, index_id: IndexId, sketch_page: &CountMinSketchPage, - f: impl FnOnce(&[u8]) -> Result, + encode_value: bool, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, refs| { Self::write_key_prefix(lower, CodecType::Statistics, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&index_id.to_le_bytes()); @@ -548,20 +584,25 @@ impl TableCodec { lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&(sketch_page.page_idx() as u32).to_be_bytes()); - f(lower.as_slice()) + if encode_value { + Self::encode_statistics_sketch_page_value_into(sketch_page, refs, value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } /// Key: `{TableName}{STATISTICS_TAG}{BOUND_MIN_TAG}{INDEX_ID}{BUCKET_TAG}{BOUND_MIN_TAG}{ORDINAL}`. - pub fn with_statistics_bucket_key( - &self, + pub fn with_statistics_bucket( + &mut self, table_name: &str, index_id: IndexId, ordinal: u32, - f: impl FnOnce(&[u8]) -> Result, + bucket: Option<&Bucket>, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, refs| { Self::write_key_prefix(lower, CodecType::Statistics, table_hash); lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&index_id.to_le_bytes()); @@ -569,89 +610,99 @@ impl TableCodec { lower.push(BOUND_MIN_TAG); lower.extend_from_slice(&ordinal.to_be_bytes()); - f(lower.as_slice()) + if let Some(bucket) = bucket { + Self::encode_statistics_bucket_value_into(bucket, refs, value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } /// Key: `View{BOUND_MIN_TAG}{ViewNameHash}`. - pub fn with_view_key( - &self, + pub fn with_view( + &mut self, view_name: &str, - f: impl FnOnce(&[u8]) -> Result, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(view_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(view_name, |lower, table_hash, value, _| { Self::write_key_prefix(lower, CodecType::View, table_hash); + f(lower.as_slice(), value.as_slice()) + }) + } - f(lower.as_slice()) + /// Key: `View{BOUND_MIN_TAG}{ViewNameHash}` with encoded view payload. + pub fn with_view_value( + &mut self, + view_name: &str, + view: &View, + arena: &A, + f: impl FnOnce(&[u8], &[u8]) -> Result, + ) -> Result { + self.clear_buffers(); + self.with_table_hash_buffers(view_name, |lower, table_hash, value, refs| { + Self::write_key_prefix(lower, CodecType::View, table_hash); + Self::encode_view_value_into(view, refs, value, arena)?; + f(lower.as_slice(), value.as_slice()) }) } /// Range bounds covering all view definitions. pub fn with_view_bound( - &self, + &mut self, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - let mut key_buffer = self.key_buffer.borrow_mut(); - Self::write_global_bound_prefix( - &mut key_buffer.lower, - VIEW_BYTES.as_slice(), - BOUND_MIN_TAG, - ); - Self::write_global_bound_prefix( - &mut key_buffer.upper, - VIEW_BYTES.as_slice(), - BOUND_MAX_TAG, - ); + self.clear_buffers(); + let (lower, upper, _) = self.slots_mut(); + Self::write_global_bound_prefix(lower, VIEW_BYTES.as_slice(), BOUND_MIN_TAG); + Self::write_global_bound_prefix(upper, VIEW_BYTES.as_slice(), BOUND_MAX_TAG); - f(key_buffer.lower.as_slice(), key_buffer.upper.as_slice()) + f(lower.as_slice(), upper.as_slice()) } /// Key: `Root{BOUND_MIN_TAG}{TableNameHash}`. - pub fn with_root_table_key( - &self, + pub fn with_root_table( + &mut self, table_name: &str, - f: impl FnOnce(&[u8]) -> Result, + meta: Option<&TableMeta>, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, refs| { Self::write_key_prefix(lower, CodecType::Root, table_hash); - f(lower.as_slice()) + if let Some(meta) = meta { + Self::encode_root_table_value_into(meta, refs, value)?; + } + + f(lower.as_slice(), value.as_slice()) }) } /// Range bounds covering all root table metadata. pub fn with_root_table_bound( - &self, + &mut self, f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - let mut key_buffer = self.key_buffer.borrow_mut(); - Self::write_global_bound_prefix( - &mut key_buffer.lower, - ROOT_BYTES.as_slice(), - BOUND_MIN_TAG, - ); - Self::write_global_bound_prefix( - &mut key_buffer.upper, - ROOT_BYTES.as_slice(), - BOUND_MAX_TAG, - ); + self.clear_buffers(); + let (lower, upper, _) = self.slots_mut(); + Self::write_global_bound_prefix(lower, ROOT_BYTES.as_slice(), BOUND_MIN_TAG); + Self::write_global_bound_prefix(upper, ROOT_BYTES.as_slice(), BOUND_MAX_TAG); - f(key_buffer.lower.as_slice(), key_buffer.upper.as_slice()) + f(lower.as_slice(), upper.as_slice()) } /// Key: `Hash{BOUND_MIN_TAG}{TableNameHash}`. - pub fn with_table_hash_key( - &self, + pub fn with_table_hash( + &mut self, table_name: &str, - f: impl FnOnce(&[u8]) -> Result, + f: impl FnOnce(&[u8], &[u8]) -> Result, ) -> Result { - self.with_table_hash(table_name, |key_buffer, table_hash| { - let lower = &mut key_buffer.lower; + self.clear_buffers(); + self.with_table_hash_buffers(table_name, |lower, table_hash, value, _| { Self::write_key_prefix(lower, CodecType::Hash, table_hash); - f(lower.as_slice()) + f(lower.as_slice(), value.as_slice()) }) } pub fn decode_tuple_key(bytes: &[u8], pk_ty: &LogicalType) -> Result { @@ -659,28 +710,46 @@ impl TableCodec { } #[inline] - pub fn decode_tuple_into( + pub fn decode_tuple_into( tuple: &mut Tuple, - deserializers: &[TupleValueSerializableImpl], + deserializers: I, tuple_id: Option, bytes: &[u8], total_len: usize, - ) -> Result<(), DatabaseError> { + ) -> Result<(), DatabaseError> + where + I: IntoIterator, + S: Borrow, + { tuple.pk = tuple_id; tuple.deserialize_from_into(deserializers, bytes, total_len) } + fn encode_index_meta_value_into( + index_meta: &IndexMeta, + reference_tables: &mut ReferenceTables, + value: &mut Bytes, + ) -> Result<(), DatabaseError> { + index_meta.encode(value, true, reference_tables, &TableArena::default()) + } + pub fn encode_index_meta_value( - &self, + &mut self, index_meta: &IndexMeta, - ) -> Result, DatabaseError> { - let mut value_bytes = BumpBytes::new_in(&self.arena); - index_meta.encode(&mut value_bytes, true, &mut ReferenceTables::new())?; - Ok(value_bytes) + ) -> Result { + self.clear_buffers(); + let (_, value, reference_tables) = self.slots_mut(); + Self::encode_index_meta_value_into(index_meta, reference_tables, value)?; + Ok(value.clone()) } pub fn decode_index_meta(bytes: &[u8]) -> Result { - IndexMeta::decode::(&mut Cursor::new(bytes), None, &ReferenceTables::new()) + IndexMeta::decode::( + &mut Cursor::new(bytes), + None, + &ReferenceTables::new(), + &mut TableArena::default(), + ) } pub fn decode_index_key( @@ -694,83 +763,138 @@ impl TableCodec { } pub fn decode_index(bytes: &[u8]) -> Result { - Ok(bincode::deserialize_from(&mut Cursor::new(bytes))?) + DataValue::decode_reference_value(&mut Cursor::new(bytes)) } - pub fn encode_column_value( - &self, - col: &ColumnRef, + fn encode_column_value_into( + col: &ColumnCatalog, reference_tables: &mut ReferenceTables, - ) -> Result, DatabaseError> { - let mut column_bytes = BumpBytes::new_in(&self.arena); - col.encode(&mut column_bytes, true, reference_tables)?; - Ok(column_bytes) + value: &mut Bytes, + ) -> Result<(), DatabaseError> { + col.encode(value, true, reference_tables, &TableArena::default()) + } + + pub fn encode_column_value(&mut self, col: &ColumnCatalog) -> Result { + self.clear_buffers(); + let (_, value, reference_tables) = self.slots_mut(); + Self::encode_column_value_into(col, reference_tables, value)?; + Ok(value.clone()) } pub fn decode_column( reader: &mut R, reference_tables: &ReferenceTables, - ) -> Result { + ) -> Result { // `TableCache` is not theoretically used in `table_collect` because `ColumnCatalog` should not depend on other Column - ColumnRef::decode::(reader, None, reference_tables) + ColumnCatalog::decode::(reader, None, reference_tables, &mut TableArena::default()) + } + + fn encode_statistics_meta_value_into( + statistics_meta: &StatisticsMetaRoot, + reference_tables: &mut ReferenceTables, + value: &mut Bytes, + ) -> Result<(), DatabaseError> { + statistics_meta.encode(value, true, reference_tables, &TableArena::default()) } pub fn encode_statistics_meta_value( - &self, + &mut self, statistics_meta: &StatisticsMetaRoot, - ) -> Result, DatabaseError> { - let mut value = BumpBytes::new_in(&self.arena); - statistics_meta.encode(&mut value, true, &mut ReferenceTables::new())?; - Ok(value) + ) -> Result { + self.clear_buffers(); + let (_, value, reference_tables) = self.slots_mut(); + Self::encode_statistics_meta_value_into(statistics_meta, reference_tables, value)?; + Ok(value.clone()) } pub fn decode_statistics_meta( bytes: &[u8], ) -> Result { - StatisticsMetaRoot::decode::(&mut Cursor::new(bytes), None, &ReferenceTables::new()) + StatisticsMetaRoot::decode::( + &mut Cursor::new(bytes), + None, + &ReferenceTables::new(), + &mut TableArena::default(), + ) + } + + fn encode_statistics_sketch_meta_value_into( + sketch_meta: &CountMinSketchMeta, + reference_tables: &mut ReferenceTables, + value: &mut Bytes, + ) -> Result<(), DatabaseError> { + sketch_meta.encode(value, true, reference_tables, &TableArena::default()) } pub fn encode_statistics_sketch_meta_value( - &self, + &mut self, sketch_meta: &CountMinSketchMeta, - ) -> Result, DatabaseError> { - let mut value = BumpBytes::new_in(&self.arena); - sketch_meta.encode(&mut value, true, &mut ReferenceTables::new())?; - Ok(value) + ) -> Result { + self.clear_buffers(); + let (_, value, reference_tables) = self.slots_mut(); + Self::encode_statistics_sketch_meta_value_into(sketch_meta, reference_tables, value)?; + Ok(value.clone()) } pub fn decode_statistics_sketch_meta( bytes: &[u8], ) -> Result { - CountMinSketchMeta::decode::(&mut Cursor::new(bytes), None, &ReferenceTables::new()) + CountMinSketchMeta::decode::( + &mut Cursor::new(bytes), + None, + &ReferenceTables::new(), + &mut TableArena::default(), + ) + } + + fn encode_statistics_sketch_page_value_into( + sketch_page: &CountMinSketchPage, + reference_tables: &mut ReferenceTables, + value: &mut Bytes, + ) -> Result<(), DatabaseError> { + sketch_page.encode(value, true, reference_tables, &TableArena::default()) } pub fn encode_statistics_sketch_page_value( - &self, + &mut self, sketch_page: &CountMinSketchPage, - ) -> Result, DatabaseError> { - let mut value = BumpBytes::new_in(&self.arena); - sketch_page.encode(&mut value, true, &mut ReferenceTables::new())?; - Ok(value) + ) -> Result { + self.clear_buffers(); + let (_, value, reference_tables) = self.slots_mut(); + Self::encode_statistics_sketch_page_value_into(sketch_page, reference_tables, value)?; + Ok(value.clone()) } pub fn decode_statistics_sketch_page( bytes: &[u8], ) -> Result { - CountMinSketchPage::decode::(&mut Cursor::new(bytes), None, &ReferenceTables::new()) + CountMinSketchPage::decode::( + &mut Cursor::new(bytes), + None, + &ReferenceTables::new(), + &mut TableArena::default(), + ) + } + + fn encode_statistics_bucket_value_into( + bucket: &Bucket, + reference_tables: &mut ReferenceTables, + value: &mut Bytes, + ) -> Result<(), DatabaseError> { + bucket.encode(value, true, reference_tables, &TableArena::default()) } pub fn encode_statistics_bucket_value( - &self, + &mut self, bucket: &Bucket, - ) -> Result, DatabaseError> { - let mut value = BumpBytes::new_in(&self.arena); - bucket.encode(&mut value, true, &mut ReferenceTables::new())?; - Ok(value) + ) -> Result { + self.clear_buffers(); + let (_, value, reference_tables) = self.slots_mut(); + Self::encode_statistics_bucket_value_into(bucket, reference_tables, value)?; + Ok(value.clone()) } pub(crate) fn decode_statistics_codec_type( - &self, key: &[u8], ) -> Result { let prefix_len = TUPLE_KEY_PREFIX_LEN + INDEX_ID_LEN; @@ -784,7 +908,12 @@ impl TableCodec { } pub fn decode_statistics_bucket(bytes: &[u8]) -> Result { - Bucket::decode::(&mut Cursor::new(bytes), None, &ReferenceTables::new()) + Bucket::decode::( + &mut Cursor::new(bytes), + None, + &ReferenceTables::new(), + &mut TableArena::default(), + ) } pub fn decode_statistics_bucket_ordinal(bytes: &[u8]) -> Result { @@ -798,20 +927,24 @@ impl TableCodec { Ok(u32::from_be_bytes(ordinal_bytes.try_into().unwrap())) } - pub fn encode_view_value(&self, view: &View) -> Result, DatabaseError> { - let mut reference_tables = ReferenceTables::new(); - let mut bytes = BumpBytes::new_in(&self.arena); + fn encode_view_value_into( + view: &View, + reference_tables: &mut ReferenceTables, + bytes: &mut Bytes, + arena: &impl MetaArena, + ) -> Result<(), DatabaseError> { + bytes.clear(); bytes.resize(4, 0u8); let reference_tables_pos = { - view.encode(&mut bytes, false, &mut reference_tables)?; + view.encode(&mut *bytes, false, reference_tables, arena)?; let pos = bytes.len(); - reference_tables.to_raw(&mut bytes)?; + reference_tables.to_raw(&mut *bytes)?; pos }; bytes[..4].copy_from_slice(&(reference_tables_pos as u32).to_le_bytes()); - Ok(bytes) + Ok(()) } pub fn decode_view( @@ -819,6 +952,7 @@ impl TableCodec { drive: (&T, &TableCache), scala_functions: &ScalaFunctions, table_functions: &TableFunctions, + arena: &mut impl MetaArena, ) -> Result { let mut cursor = Cursor::new(bytes); let reference_tables_pos = { @@ -832,22 +966,33 @@ impl TableCodec { let context = ReferenceDecodeContext::with_functions(Some(drive), scala_functions, table_functions); - View::decode(&mut cursor, Some(&context), &reference_tables) + View::decode(&mut cursor, Some(&context), &reference_tables, arena) } - pub fn encode_root_table_value( - &self, + fn encode_root_table_value_into( meta: &TableMeta, - ) -> Result, DatabaseError> { - let mut meta_bytes = BumpBytes::new_in(&self.arena); - meta.encode(&mut meta_bytes, true, &mut ReferenceTables::new())?; - Ok(meta_bytes) + reference_tables: &mut ReferenceTables, + value: &mut Bytes, + ) -> Result<(), DatabaseError> { + meta.encode(value, true, reference_tables, &TableArena::default()) + } + + pub fn encode_root_table_value(&mut self, meta: &TableMeta) -> Result { + self.clear_buffers(); + let (_, value, reference_tables) = self.slots_mut(); + Self::encode_root_table_value_into(meta, reference_tables, value)?; + Ok(value.clone()) } pub fn decode_root_table(bytes: &[u8]) -> Result { let mut bytes = Cursor::new(bytes); - TableMeta::decode::(&mut bytes, None, &ReferenceTables::new()) + TableMeta::decode::( + &mut bytes, + None, + &ReferenceTables::new(), + &mut TableArena::default(), + ) } } @@ -855,12 +1000,11 @@ impl TableCodec { mod tests { use crate::binder::test::build_t1_table; use crate::catalog::view::View; - use crate::catalog::{ - ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation, TableCatalog, TableMeta, - }; + use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRelation, TableCatalog, TableMeta}; use crate::errors::DatabaseError; use crate::optimizer::core::histogram::HistogramBuilder; use crate::optimizer::core::statistics_meta::StatisticsMeta; + use crate::planner::{PlanArena, TableArenaCell}; use crate::serdes::ReferenceTables; use crate::storage::rocksdb::RocksTransaction; use crate::storage::table_codec::{Bytes, TableCodec}; @@ -877,7 +1021,7 @@ mod tests { use std::sync::Arc; use ulid::Ulid; - fn build_table_codec() -> TableCatalog { + fn build_table_codec(table_arena: &TableArenaCell) -> TableCatalog { let columns = vec![ ColumnCatalog::new( "c1".into(), @@ -890,28 +1034,30 @@ mod tests { ColumnDesc::new(LogicalType::Decimal(None, None), None, false, None).unwrap(), ), ]; - TableCatalog::new("t1".to_string().into(), columns).unwrap() + TableCatalog::new("t1".to_string().into(), columns, table_arena.borrow_mut()).unwrap() } #[test] fn test_table_codec_tuple() -> Result<(), DatabaseError> { - let table_codec = TableCodec::default(); - let table_catalog = build_table_codec(); + let table_arena = TableArenaCell::default(); + let table_catalog = build_table_codec(&table_arena); + let plan_arena = PlanArena::new(&table_arena); let expected = Tuple::new( Some(DataValue::Int32(0)), vec![DataValue::Int32(0), DataValue::Decimal(Decimal::new(1, 0))], ); - let bytes = expected.serialize_to( + let mut bytes = Vec::new(); + expected.serialize_to( &[ LogicalType::Integer.serializable(), LogicalType::Decimal(None, None).serializable(), ], - table_codec.arena(), + &mut bytes, )?; let deserializers = table_catalog .columns() - .map(|column| column.datatype().serializable()) + .map(|column| plan_arena.column(*column).datatype().serializable()) .collect_vec(); let mut tuple = Tuple::default(); @@ -929,8 +1075,9 @@ mod tests { #[test] fn test_root_catalog() { - let table_codec = TableCodec::default(); - let table_catalog = build_table_codec(); + let mut table_codec = TableCodec::default(); + let table_arena = TableArenaCell::default(); + let table_catalog = build_table_codec(&table_arena); let bytes = table_codec .encode_root_table_value(&TableMeta { table_name: table_catalog.name.clone(), @@ -944,7 +1091,7 @@ mod tests { #[test] fn test_table_codec_statistics_meta() -> Result<(), DatabaseError> { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let index_meta = IndexMeta { id: 0, column_ids: vec![Ulid::new()], @@ -988,10 +1135,12 @@ mod tests { TableCodec::decode_statistics_sketch_page::(&sketch_page_bytes)?; assert_eq!(decoded_sketch_page.counters(), first_sketch_page.counters()); - let bucket0_key = table_codec - .with_statistics_bucket_key("t1", 0, 0, |key| Ok::<_, DatabaseError>(key.to_vec()))?; - let bucket1_key = table_codec - .with_statistics_bucket_key("t1", 0, 1, |key| Ok::<_, DatabaseError>(key.to_vec()))?; + let bucket0_key = table_codec.with_statistics_bucket("t1", 0, 0, None, |key, _| { + Ok::<_, DatabaseError>(key.to_vec()) + })?; + let bucket1_key = table_codec.with_statistics_bucket("t1", 0, 1, None, |key, _| { + Ok::<_, DatabaseError>(key.to_vec()) + })?; assert!(bucket0_key < bucket1_key); let (bucket0_min, bucket0_max) = @@ -1019,7 +1168,7 @@ mod tests { #[test] fn test_table_codec_index_meta() -> Result<(), DatabaseError> { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let index_meta = IndexMeta { id: 0, column_ids: vec![Ulid::new()], @@ -1042,7 +1191,8 @@ mod tests { #[test] fn test_table_codec_index() -> Result<(), DatabaseError> { let tuple_id = DataValue::Int32(0); - let bytes = bincode::serialize(&tuple_id)?; + let mut bytes = Vec::new(); + tuple_id.encode_reference_value(&mut bytes)?; assert_eq!(TableCodec::decode_index(&bytes)?, tuple_id); @@ -1061,93 +1211,152 @@ mod tests { table_name: "t1".to_string().into(), is_temp: false, }; - let col = ColumnRef::from(col); + let expected_col = col.clone(); let mut reference_tables = ReferenceTables::new(); + reference_tables.push_or_replace(&"t1".to_string().into()); - let table_codec = TableCodec::default(); - let bytes = table_codec - .encode_column_value(&col, &mut reference_tables) - .unwrap(); + let mut table_codec = TableCodec::default(); + let bytes = table_codec.encode_column_value(&col).unwrap(); let mut cursor = Cursor::new(bytes); let decode_col = TableCodec::decode_column::(&mut cursor, &reference_tables)?; - assert_eq!(decode_col, col); + assert_eq!(decode_col, expected_col); Ok(()) } #[test] fn test_table_codec_view() -> Result<(), DatabaseError> { - let table_codec = TableCodec::default(); + fn normalize_explain(explain: String) -> String { + let mut normalized = String::with_capacity(explain.len()); + let bytes = explain.as_bytes(); + let mut i = 0; + while i < bytes.len() { + if bytes[i..].starts_with(b"pos: ") { + normalized.push_str("pos: _"); + i += b"pos: ".len(); + while i < bytes.len() && bytes[i].is_ascii_digit() { + i += 1; + } + } else if bytes[i] == b'#' { + normalized.push_str("#_"); + i += 1; + while i < bytes.len() && bytes[i].is_ascii_digit() { + i += 1; + } + } else { + normalized.push(bytes[i] as char); + i += 1; + } + } + normalized + } + + let mut table_codec = TableCodec::default(); let table_state = build_t1_table()?; let scala_functions = Default::default(); let table_functions = Default::default(); + let build_view = |name: &str, sql: &str| -> Result<(View, PlanArena<'_>), DatabaseError> { + let mut plan_arena = PlanArena::new(&table_state.table_arena); + let mut plan = table_state.plan_with_arena(sql, &mut plan_arena)?; + let schema = plan.output_schema(&mut plan_arena).clone(); + Ok(( + View { + name: name.to_string().into(), + plan: Box::new(plan), + schema, + }, + plan_arena, + )) + }; // Subquery { println!("==== Subquery"); - let plan = table_state - .plan("select * from t1 where c1 in (select c1 from t1 where c1 > 1)")?; - println!("{plan:#?}"); - let view = View { - name: "view_subquery".to_string().into(), - plan: Box::new(plan), - }; - let bytes = table_codec.encode_view_value(&view)?; + let (view, mut plan_arena) = build_view( + "view_subquery", + "select * from t1 where c1 in (select c1 from t1 where c1 > 1)", + )?; + println!("{:#?}", view.plan); let transaction = table_state.storage.transaction()?; + let mut decode_arena = PlanArena::new(&table_state.table_arena); + let decoded = table_codec.with_view_value( + view.name.as_ref(), + &view, + &plan_arena, + |_, value| { + TableCodec::decode_view( + value, + (&transaction, &table_state.table_cache), + &scala_functions, + &table_functions, + &mut decode_arena, + ) + }, + )?; assert_eq!( - view, - TableCodec::decode_view( - &bytes, - (&transaction, &table_state.table_cache), - &scala_functions, - &table_functions, - )? + normalize_explain(view.plan.explain(&mut plan_arena, 0)), + normalize_explain(decoded.plan.explain(&mut decode_arena, 0)) ); + assert_eq!(view.schema.len(), decoded.schema.len()); } // No Join { println!("==== No Join"); - let plan = table_state.plan("select * from t1 where c1 > 1")?; - let view = View { - name: "view_filter".to_string().into(), - plan: Box::new(plan), - }; - let bytes = table_codec.encode_view_value(&view)?; + let (view, mut plan_arena) = + build_view("view_filter", "select * from t1 where c1 > 1")?; let transaction = table_state.storage.transaction()?; + let mut decode_arena = PlanArena::new(&table_state.table_arena); + let decoded = table_codec.with_view_value( + view.name.as_ref(), + &view, + &plan_arena, + |_, value| { + TableCodec::decode_view( + value, + (&transaction, &table_state.table_cache), + &scala_functions, + &table_functions, + &mut decode_arena, + ) + }, + )?; assert_eq!( - view, - TableCodec::decode_view( - &bytes, - (&transaction, &table_state.table_cache), - &scala_functions, - &table_functions, - )? + normalize_explain(view.plan.explain(&mut plan_arena, 0)), + normalize_explain(decoded.plan.explain(&mut decode_arena, 0)) ); + assert_eq!(view.schema.len(), decoded.schema.len()); } // Join { println!("==== Join"); - let plan = table_state.plan("select * from t1 left join t2 on c1 = c3")?; - let view = View { - name: "view_join".to_string().into(), - plan: Box::new(plan), - }; - let bytes = table_codec.encode_view_value(&view)?; + let (view, mut plan_arena) = + build_view("view_join", "select * from t1 left join t2 on c1 = c3")?; let transaction = table_state.storage.transaction()?; + let mut decode_arena = PlanArena::new(&table_state.table_arena); + let decoded = table_codec.with_view_value( + view.name.as_ref(), + &view, + &plan_arena, + |_, value| { + TableCodec::decode_view( + value, + (&transaction, &table_state.table_cache), + &scala_functions, + &table_functions, + &mut decode_arena, + ) + }, + )?; assert_eq!( - view, - TableCodec::decode_view( - &bytes, - (&transaction, &table_state.table_cache), - &scala_functions, - &table_functions, - )? + normalize_explain(view.plan.explain(&mut plan_arena, 0)), + normalize_explain(decoded.plan.explain(&mut decode_arena, 0)) ); + assert_eq!(view.schema.len(), decoded.schema.len()); } Ok(()) @@ -1156,9 +1365,10 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_column_bound() { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |col_id: usize, table_name: &str| { + let mut table_codec = TableCodec::default(); let mut col = ColumnCatalog::new( "".to_string(), false, @@ -1172,9 +1382,7 @@ mod tests { }; table_codec - .with_column_key(&ColumnRef::from(col), |key| { - Ok::<_, DatabaseError>(key.to_vec()) - }) + .with_column(&col, false, |key, _| Ok::<_, DatabaseError>(key.to_vec())) .unwrap() }; @@ -1213,9 +1421,10 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_index_meta_bound() { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |index_id: usize, table_name: &str| { + let mut table_codec = TableCodec::default(); let index_meta = IndexMeta { id: index_id as u32, column_ids: vec![], @@ -1227,7 +1436,7 @@ mod tests { }; table_codec - .with_index_meta_key(table_name, index_meta.id, |key| { + .with_index_meta(table_name, index_meta.id, None, |key, _| { Ok::<_, DatabaseError>(key.to_vec()) }) .unwrap() @@ -1268,16 +1477,23 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_index_bound() { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let column = ColumnCatalog::new( "".to_string(), false, ColumnDesc::new(LogicalType::Boolean, None, false, None).unwrap(), ); - let table_catalog = TableCatalog::new("T0".to_string().into(), vec![column]).unwrap(); + let table_arena = TableArenaCell::default(); + let table_catalog = TableCatalog::new( + "T0".to_string().into(), + vec![column], + table_arena.borrow_mut(), + ) + .unwrap(); let op = |value: DataValue, index_id: usize, table_name: &str| { + let mut table_codec = TableCodec::default(); let value = Arc::new(value); let index = Index::new( index_id as u32, @@ -1286,7 +1502,7 @@ mod tests { ); table_codec - .with_index_key(table_name, &index, None, |key| { + .with_index(table_name, &index, None, |key, _| { Ok::<_, DatabaseError>(key.to_vec()) }) .unwrap() @@ -1327,9 +1543,10 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_index_all_bound() { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |value: DataValue, index_id: usize, table_name: &str| { + let mut table_codec = TableCodec::default(); let value = Arc::new(value); let index = Index::new( index_id as u32, @@ -1338,7 +1555,7 @@ mod tests { ); table_codec - .with_index_key(table_name, &index, None, |key| { + .with_index(table_name, &index, None, |key, _| { Ok::<_, DatabaseError>(key.to_vec()) }) .unwrap() @@ -1379,11 +1596,12 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_table_codec_tuple_bound() { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |tuple_id: DataValue, table_name: &str| { + let mut table_codec = TableCodec::default(); table_codec - .with_tuple_key(table_name, &Arc::new(tuple_id), |key| { + .with_tuple(table_name, &Arc::new(tuple_id), None, |key, _| { Ok::<_, DatabaseError>(key.to_vec()) }) .unwrap() @@ -1424,11 +1642,14 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_root_codec_name_bound() { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let mut set: BTreeSet = BTreeSet::new(); let op = |table_name: &str| { + let mut table_codec = TableCodec::default(); table_codec - .with_root_table_key(table_name, |key| Ok::<_, DatabaseError>(key.to_vec())) + .with_root_table(table_name, None, |key, _| { + Ok::<_, DatabaseError>(key.to_vec()) + }) .unwrap() }; @@ -1460,11 +1681,12 @@ mod tests { #[test] #[allow(clippy::mutable_key_type)] fn test_view_codec_name_bound() { - let table_codec = TableCodec::default(); + let mut table_codec = TableCodec::default(); let mut set = BTreeSet::new(); let op = |view_name: &str| { + let mut table_codec = TableCodec::default(); table_codec - .with_view_key(view_name, |key| Ok::<_, DatabaseError>(key.to_vec())) + .with_view(view_name, |key, _| Ok::<_, DatabaseError>(key.to_vec())) .unwrap() }; diff --git a/src/types/evaluator/binary.rs b/src/types/evaluator/binary.rs index e16eec9d..f50212fa 100644 --- a/src/types/evaluator/binary.rs +++ b/src/types/evaluator/binary.rs @@ -15,8 +15,11 @@ use crate::errors::DatabaseError; use crate::expression::BinaryOperator; use crate::types::evaluator::boolean::*; +#[cfg(feature = "time")] use crate::types::evaluator::date::*; +#[cfg(feature = "time")] use crate::types::evaluator::datetime::*; +#[cfg(feature = "decimal")] use crate::types::evaluator::decimal::*; use crate::types::evaluator::float32::*; use crate::types::evaluator::float64::*; @@ -24,312 +27,630 @@ use crate::types::evaluator::int16::*; use crate::types::evaluator::int32::*; use crate::types::evaluator::int64::*; use crate::types::evaluator::int8::*; -use crate::types::evaluator::null::NullBinaryEvaluator; +use crate::types::evaluator::null::*; +#[cfg(feature = "time")] use crate::types::evaluator::time32::*; +#[cfg(feature = "time")] use crate::types::evaluator::time64::*; -use crate::types::evaluator::tuple::{ - TupleEqBinaryEvaluator, TupleGtBinaryEvaluator, TupleGtEqBinaryEvaluator, - TupleLtBinaryEvaluator, TupleLtEqBinaryEvaluator, TupleNotEqBinaryEvaluator, -}; +use crate::types::evaluator::tuple::*; use crate::types::evaluator::uint16::*; use crate::types::evaluator::uint32::*; use crate::types::evaluator::uint64::*; use crate::types::evaluator::uint8::*; use crate::types::evaluator::utf8::*; -use crate::types::evaluator::BinaryEvaluatorBox; +use crate::types::evaluator::{BinaryEvaluatorParams, BinaryEvaluatorRef}; use crate::types::LogicalType; use paste::paste; use std::borrow::Cow; -use std::sync::Arc; -macro_rules! box_binary { - ($ty:expr, $op:expr, $evaluator:expr) => { - Ok(BinaryEvaluatorBox::new( - Arc::new($evaluator), - $ty.clone(), - $op, - )) - }; +const NUMERIC_PLUS_OFFSET: u16 = 0; +const NUMERIC_MINUS_OFFSET: u16 = 1; +const NUMERIC_MULTIPLY_OFFSET: u16 = 2; +const NUMERIC_DIVIDE_OFFSET: u16 = 3; +const NUMERIC_GT_OFFSET: u16 = 4; +const NUMERIC_GT_EQ_OFFSET: u16 = 5; +const NUMERIC_LT_OFFSET: u16 = 6; +const NUMERIC_LT_EQ_OFFSET: u16 = 7; +const NUMERIC_EQ_OFFSET: u16 = 8; +const NUMERIC_NOT_EQ_OFFSET: u16 = 9; +const NUMERIC_MODULO_OFFSET: u16 = 10; +const NUMERIC_OPS_LEN: u16 = NUMERIC_MODULO_OFFSET + 1; + +#[cfg(feature = "time")] +const TIME_PLUS_OFFSET: u16 = 0; +#[cfg(feature = "time")] +const TIME_MINUS_OFFSET: u16 = 1; +#[cfg(feature = "time")] +const TIME_GT_OFFSET: u16 = 2; +#[cfg(feature = "time")] +const TIME_GT_EQ_OFFSET: u16 = 3; +#[cfg(feature = "time")] +const TIME_LT_OFFSET: u16 = 4; +#[cfg(feature = "time")] +const TIME_LT_EQ_OFFSET: u16 = 5; +#[cfg(feature = "time")] +const TIME_EQ_OFFSET: u16 = 6; +#[cfg(feature = "time")] +const TIME_NOT_EQ_OFFSET: u16 = 7; +const TIME_OPS_LEN: u16 = 8; + +#[cfg(feature = "time")] +const TIMESTAMP_GT_OFFSET: u16 = 0; +#[cfg(feature = "time")] +const TIMESTAMP_GT_EQ_OFFSET: u16 = 1; +#[cfg(feature = "time")] +const TIMESTAMP_LT_OFFSET: u16 = 2; +#[cfg(feature = "time")] +const TIMESTAMP_LT_EQ_OFFSET: u16 = 3; +#[cfg(feature = "time")] +const TIMESTAMP_EQ_OFFSET: u16 = 4; +#[cfg(feature = "time")] +const TIMESTAMP_NOT_EQ_OFFSET: u16 = 5; +const TIMESTAMP_OPS_LEN: u16 = 6; + +const BOOLEAN_AND_OFFSET: u16 = 0; +const BOOLEAN_OR_OFFSET: u16 = 1; +const BOOLEAN_EQ_OFFSET: u16 = 2; +const BOOLEAN_NOT_EQ_OFFSET: u16 = 3; +const BOOLEAN_OPS_LEN: u16 = BOOLEAN_NOT_EQ_OFFSET + 1; + +const UTF8_GT_OFFSET: u16 = 0; +const UTF8_LT_OFFSET: u16 = 1; +const UTF8_GT_EQ_OFFSET: u16 = 2; +const UTF8_LT_EQ_OFFSET: u16 = 3; +const UTF8_EQ_OFFSET: u16 = 4; +const UTF8_NOT_EQ_OFFSET: u16 = 5; +const UTF8_STRING_CONCAT_OFFSET: u16 = 6; +const UTF8_LIKE_OFFSET: u16 = 7; +const UTF8_NOT_LIKE_OFFSET: u16 = 8; +const UTF8_OPS_LEN: u16 = UTF8_NOT_LIKE_OFFSET + 1; + +const SQL_NULL_OPS_LEN: u16 = 1; + +const TUPLE_EQ_OFFSET: u16 = 0; +const TUPLE_NOT_EQ_OFFSET: u16 = 1; +const TUPLE_GT_OFFSET: u16 = 2; +const TUPLE_GT_EQ_OFFSET: u16 = 3; +const TUPLE_LT_OFFSET: u16 = 4; +const TUPLE_LT_EQ_OFFSET: u16 = 5; + +const BINARY_INT8_BASE: u16 = 0; +const BINARY_INT16_BASE: u16 = BINARY_INT8_BASE + NUMERIC_OPS_LEN; +const BINARY_INT32_BASE: u16 = BINARY_INT16_BASE + NUMERIC_OPS_LEN; +const BINARY_INT64_BASE: u16 = BINARY_INT32_BASE + NUMERIC_OPS_LEN; +const BINARY_UINT8_BASE: u16 = BINARY_INT64_BASE + NUMERIC_OPS_LEN; +const BINARY_UINT16_BASE: u16 = BINARY_UINT8_BASE + NUMERIC_OPS_LEN; +const BINARY_UINT32_BASE: u16 = BINARY_UINT16_BASE + NUMERIC_OPS_LEN; +const BINARY_UINT64_BASE: u16 = BINARY_UINT32_BASE + NUMERIC_OPS_LEN; +const BINARY_FLOAT32_BASE: u16 = BINARY_UINT64_BASE + NUMERIC_OPS_LEN; +const BINARY_FLOAT64_BASE: u16 = BINARY_FLOAT32_BASE + NUMERIC_OPS_LEN; +const BINARY_DATE_BASE: u16 = BINARY_FLOAT64_BASE + NUMERIC_OPS_LEN; +const BINARY_DATETIME_BASE: u16 = BINARY_DATE_BASE + NUMERIC_OPS_LEN; +const BINARY_DECIMAL_BASE: u16 = BINARY_DATETIME_BASE + NUMERIC_OPS_LEN; +const BINARY_TIME_BASE: u16 = BINARY_DECIMAL_BASE + NUMERIC_OPS_LEN; +const BINARY_TIME64_BASE: u16 = BINARY_TIME_BASE + TIME_OPS_LEN; +const BINARY_BOOLEAN_BASE: u16 = BINARY_TIME64_BASE + TIMESTAMP_OPS_LEN; +const BINARY_UTF8_BASE: u16 = BINARY_BOOLEAN_BASE + BOOLEAN_OPS_LEN; +const BINARY_SQL_NULL: u16 = BINARY_UTF8_BASE + UTF8_OPS_LEN; +const BINARY_TUPLE_BASE: u16 = BINARY_SQL_NULL + SQL_NULL_OPS_LEN; + +// Evaluator positions are serialized ABI. Do not reorder or reuse existing +// positions; only append new positions at the end of the current layout. + +const fn binary_pos(base: u16, offset: u16) -> u16 { + base + offset } -macro_rules! numeric_binary_evaluator { - ($value_type:ident, $op:expr, $ty:expr) => { - paste! { - match $op { - BinaryOperator::Plus => box_binary!($ty, $op, [<$value_type PlusBinaryEvaluator>]), - BinaryOperator::Minus => box_binary!($ty, $op, [<$value_type MinusBinaryEvaluator>]), - BinaryOperator::Multiply => box_binary!($ty, $op, [<$value_type MultiplyBinaryEvaluator>]), - BinaryOperator::Divide => box_binary!($ty, $op, [<$value_type DivideBinaryEvaluator>]), - BinaryOperator::Gt => box_binary!($ty, $op, [<$value_type GtBinaryEvaluator>]), - BinaryOperator::GtEq => box_binary!($ty, $op, [<$value_type GtEqBinaryEvaluator>]), - BinaryOperator::Lt => box_binary!($ty, $op, [<$value_type LtBinaryEvaluator>]), - BinaryOperator::LtEq => box_binary!($ty, $op, [<$value_type LtEqBinaryEvaluator>]), - BinaryOperator::Eq => box_binary!($ty, $op, [<$value_type EqBinaryEvaluator>]), - BinaryOperator::NotEq => box_binary!($ty, $op, [<$value_type NotEqBinaryEvaluator>]), - BinaryOperator::Modulo => box_binary!($ty, $op, [<$value_type ModBinaryEvaluator>]), - _ => Err(DatabaseError::UnsupportedBinaryOperator($ty.clone(), $op)), - } - } - }; +fn unit_binary_ref(pos: u16) -> Result { + Ok(BinaryEvaluatorRef::new(pos, BinaryEvaluatorParams::Unit)) +} + +fn numeric_binary_pos( + base: u16, + ty: &LogicalType, + op: BinaryOperator, +) -> Result { + Ok(base + + match op { + BinaryOperator::Plus => NUMERIC_PLUS_OFFSET, + BinaryOperator::Minus => NUMERIC_MINUS_OFFSET, + BinaryOperator::Multiply => NUMERIC_MULTIPLY_OFFSET, + BinaryOperator::Divide => NUMERIC_DIVIDE_OFFSET, + BinaryOperator::Gt => NUMERIC_GT_OFFSET, + BinaryOperator::GtEq => NUMERIC_GT_EQ_OFFSET, + BinaryOperator::Lt => NUMERIC_LT_OFFSET, + BinaryOperator::LtEq => NUMERIC_LT_EQ_OFFSET, + BinaryOperator::Eq => NUMERIC_EQ_OFFSET, + BinaryOperator::NotEq => NUMERIC_NOT_EQ_OFFSET, + BinaryOperator::Modulo => NUMERIC_MODULO_OFFSET, + _ => return Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), + }) +} + +fn numeric_binary_ref( + base: u16, + ty: &LogicalType, + op: BinaryOperator, +) -> Result { + unit_binary_ref(numeric_binary_pos(base, ty, op)?) } pub fn binary_create( ty: Cow<'_, LogicalType>, op: BinaryOperator, -) -> Result { +) -> Result { let ty = ty.as_ref(); match ty { - LogicalType::Tinyint => numeric_binary_evaluator!(Int8, op, ty), - LogicalType::Smallint => numeric_binary_evaluator!(Int16, op, ty), - LogicalType::Integer => numeric_binary_evaluator!(Int32, op, ty), - LogicalType::Bigint => numeric_binary_evaluator!(Int64, op, ty), - LogicalType::UTinyint => numeric_binary_evaluator!(UInt8, op, ty), - LogicalType::USmallint => numeric_binary_evaluator!(UInt16, op, ty), - LogicalType::UInteger => numeric_binary_evaluator!(UInt32, op, ty), - LogicalType::UBigint => numeric_binary_evaluator!(UInt64, op, ty), - LogicalType::Float => numeric_binary_evaluator!(Float32, op, ty), - LogicalType::Double => numeric_binary_evaluator!(Float64, op, ty), - LogicalType::Date => numeric_binary_evaluator!(Date, op, ty), - LogicalType::DateTime => numeric_binary_evaluator!(DateTime, op, ty), + LogicalType::Tinyint => numeric_binary_ref(BINARY_INT8_BASE, ty, op), + LogicalType::Smallint => numeric_binary_ref(BINARY_INT16_BASE, ty, op), + LogicalType::Integer => numeric_binary_ref(BINARY_INT32_BASE, ty, op), + LogicalType::Bigint => numeric_binary_ref(BINARY_INT64_BASE, ty, op), + LogicalType::UTinyint => numeric_binary_ref(BINARY_UINT8_BASE, ty, op), + LogicalType::USmallint => numeric_binary_ref(BINARY_UINT16_BASE, ty, op), + LogicalType::UInteger => numeric_binary_ref(BINARY_UINT32_BASE, ty, op), + LogicalType::UBigint => numeric_binary_ref(BINARY_UINT64_BASE, ty, op), + LogicalType::Float => numeric_binary_ref(BINARY_FLOAT32_BASE, ty, op), + LogicalType::Double => numeric_binary_ref(BINARY_FLOAT64_BASE, ty, op), + #[cfg(feature = "time")] + LogicalType::Date => numeric_binary_ref(BINARY_DATE_BASE, ty, op), + #[cfg(not(feature = "time"))] + LogicalType::Date => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), + #[cfg(feature = "time")] + LogicalType::DateTime => numeric_binary_ref(BINARY_DATETIME_BASE, ty, op), + #[cfg(not(feature = "time"))] + LogicalType::DateTime => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), + #[cfg(feature = "time")] LogicalType::Time(_) => match op { - BinaryOperator::Plus => box_binary!(ty, op, TimePlusBinaryEvaluator), - BinaryOperator::Minus => box_binary!(ty, op, TimeMinusBinaryEvaluator), - BinaryOperator::Gt => box_binary!(ty, op, TimeGtBinaryEvaluator), - BinaryOperator::GtEq => box_binary!(ty, op, TimeGtEqBinaryEvaluator), - BinaryOperator::Lt => box_binary!(ty, op, TimeLtBinaryEvaluator), - BinaryOperator::LtEq => box_binary!(ty, op, TimeLtEqBinaryEvaluator), - BinaryOperator::Eq => box_binary!(ty, op, TimeEqBinaryEvaluator), - BinaryOperator::NotEq => box_binary!(ty, op, TimeNotEqBinaryEvaluator), + BinaryOperator::Plus => unit_binary_ref(binary_pos(BINARY_TIME_BASE, TIME_PLUS_OFFSET)), + BinaryOperator::Minus => { + unit_binary_ref(binary_pos(BINARY_TIME_BASE, TIME_MINUS_OFFSET)) + } + BinaryOperator::Gt => unit_binary_ref(binary_pos(BINARY_TIME_BASE, TIME_GT_OFFSET)), + BinaryOperator::GtEq => { + unit_binary_ref(binary_pos(BINARY_TIME_BASE, TIME_GT_EQ_OFFSET)) + } + BinaryOperator::Lt => unit_binary_ref(binary_pos(BINARY_TIME_BASE, TIME_LT_OFFSET)), + BinaryOperator::LtEq => { + unit_binary_ref(binary_pos(BINARY_TIME_BASE, TIME_LT_EQ_OFFSET)) + } + BinaryOperator::Eq => unit_binary_ref(binary_pos(BINARY_TIME_BASE, TIME_EQ_OFFSET)), + BinaryOperator::NotEq => { + unit_binary_ref(binary_pos(BINARY_TIME_BASE, TIME_NOT_EQ_OFFSET)) + } _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), }, + #[cfg(not(feature = "time"))] + LogicalType::Time(_) => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), + #[cfg(feature = "time")] LogicalType::TimeStamp(_, _) => match op { - BinaryOperator::Gt => box_binary!(ty, op, Time64GtBinaryEvaluator), - BinaryOperator::GtEq => box_binary!(ty, op, Time64GtEqBinaryEvaluator), - BinaryOperator::Lt => box_binary!(ty, op, Time64LtBinaryEvaluator), - BinaryOperator::LtEq => box_binary!(ty, op, Time64LtEqBinaryEvaluator), - BinaryOperator::Eq => box_binary!(ty, op, Time64EqBinaryEvaluator), - BinaryOperator::NotEq => box_binary!(ty, op, Time64NotEqBinaryEvaluator), + BinaryOperator::Gt => { + unit_binary_ref(binary_pos(BINARY_TIME64_BASE, TIMESTAMP_GT_OFFSET)) + } + BinaryOperator::GtEq => { + unit_binary_ref(binary_pos(BINARY_TIME64_BASE, TIMESTAMP_GT_EQ_OFFSET)) + } + BinaryOperator::Lt => { + unit_binary_ref(binary_pos(BINARY_TIME64_BASE, TIMESTAMP_LT_OFFSET)) + } + BinaryOperator::LtEq => { + unit_binary_ref(binary_pos(BINARY_TIME64_BASE, TIMESTAMP_LT_EQ_OFFSET)) + } + BinaryOperator::Eq => { + unit_binary_ref(binary_pos(BINARY_TIME64_BASE, TIMESTAMP_EQ_OFFSET)) + } + BinaryOperator::NotEq => { + unit_binary_ref(binary_pos(BINARY_TIME64_BASE, TIMESTAMP_NOT_EQ_OFFSET)) + } _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), }, - LogicalType::Decimal(_, _) => numeric_binary_evaluator!(Decimal, op, ty), + #[cfg(not(feature = "time"))] + LogicalType::TimeStamp(_, _) => { + Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)) + } + #[cfg(feature = "decimal")] + LogicalType::Decimal(_, _) => numeric_binary_ref(BINARY_DECIMAL_BASE, ty, op), + #[cfg(not(feature = "decimal"))] + LogicalType::Decimal(_, _) => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), LogicalType::Boolean => match op { - BinaryOperator::And => box_binary!(ty, op, BooleanAndBinaryEvaluator), - BinaryOperator::Or => box_binary!(ty, op, BooleanOrBinaryEvaluator), - BinaryOperator::Eq => box_binary!(ty, op, BooleanEqBinaryEvaluator), - BinaryOperator::NotEq => box_binary!(ty, op, BooleanNotEqBinaryEvaluator), + BinaryOperator::And => { + unit_binary_ref(binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_AND_OFFSET)) + } + BinaryOperator::Or => { + unit_binary_ref(binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_OR_OFFSET)) + } + BinaryOperator::Eq => { + unit_binary_ref(binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_EQ_OFFSET)) + } + BinaryOperator::NotEq => { + unit_binary_ref(binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_NOT_EQ_OFFSET)) + } _ => Err(DatabaseError::UnsupportedBinaryOperator( LogicalType::Boolean, op, )), }, LogicalType::Varchar(_, _) | LogicalType::Char(_, _) => match op { - BinaryOperator::Gt => box_binary!(ty, op, Utf8GtBinaryEvaluator), - BinaryOperator::Lt => box_binary!(ty, op, Utf8LtBinaryEvaluator), - BinaryOperator::GtEq => box_binary!(ty, op, Utf8GtEqBinaryEvaluator), - BinaryOperator::LtEq => box_binary!(ty, op, Utf8LtEqBinaryEvaluator), - BinaryOperator::Eq => box_binary!(ty, op, Utf8EqBinaryEvaluator), - BinaryOperator::NotEq => box_binary!(ty, op, Utf8NotEqBinaryEvaluator), - BinaryOperator::StringConcat => box_binary!(ty, op, Utf8StringConcatBinaryEvaluator), - BinaryOperator::Like(escape_char) => { - box_binary!(ty, op, Utf8LikeBinaryEvaluator { escape_char }) - } - BinaryOperator::NotLike(escape_char) => { - box_binary!(ty, op, Utf8NotLikeBinaryEvaluator { escape_char }) + BinaryOperator::Gt => unit_binary_ref(binary_pos(BINARY_UTF8_BASE, UTF8_GT_OFFSET)), + BinaryOperator::Lt => unit_binary_ref(binary_pos(BINARY_UTF8_BASE, UTF8_LT_OFFSET)), + BinaryOperator::GtEq => { + unit_binary_ref(binary_pos(BINARY_UTF8_BASE, UTF8_GT_EQ_OFFSET)) + } + BinaryOperator::LtEq => { + unit_binary_ref(binary_pos(BINARY_UTF8_BASE, UTF8_LT_EQ_OFFSET)) } + BinaryOperator::Eq => unit_binary_ref(binary_pos(BINARY_UTF8_BASE, UTF8_EQ_OFFSET)), + BinaryOperator::NotEq => { + unit_binary_ref(binary_pos(BINARY_UTF8_BASE, UTF8_NOT_EQ_OFFSET)) + } + BinaryOperator::StringConcat => { + unit_binary_ref(binary_pos(BINARY_UTF8_BASE, UTF8_STRING_CONCAT_OFFSET)) + } + BinaryOperator::Like(escape_char) => Ok(BinaryEvaluatorRef::new( + binary_pos(BINARY_UTF8_BASE, UTF8_LIKE_OFFSET), + BinaryEvaluatorParams::Like { escape_char }, + )), + BinaryOperator::NotLike(escape_char) => Ok(BinaryEvaluatorRef::new( + binary_pos(BINARY_UTF8_BASE, UTF8_NOT_LIKE_OFFSET), + BinaryEvaluatorParams::Like { escape_char }, + )), _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), }, - LogicalType::SqlNull => box_binary!(ty, op, NullBinaryEvaluator), + LogicalType::SqlNull => unit_binary_ref(BINARY_SQL_NULL), LogicalType::Tuple(_) => match op { - BinaryOperator::Eq => box_binary!(ty, op, TupleEqBinaryEvaluator), - BinaryOperator::NotEq => box_binary!(ty, op, TupleNotEqBinaryEvaluator), - BinaryOperator::Gt => box_binary!(ty, op, TupleGtBinaryEvaluator), - BinaryOperator::GtEq => box_binary!(ty, op, TupleGtEqBinaryEvaluator), - BinaryOperator::Lt => box_binary!(ty, op, TupleLtBinaryEvaluator), - BinaryOperator::LtEq => box_binary!(ty, op, TupleLtEqBinaryEvaluator), + BinaryOperator::Eq => unit_binary_ref(binary_pos(BINARY_TUPLE_BASE, TUPLE_EQ_OFFSET)), + BinaryOperator::NotEq => { + unit_binary_ref(binary_pos(BINARY_TUPLE_BASE, TUPLE_NOT_EQ_OFFSET)) + } + BinaryOperator::Gt => unit_binary_ref(binary_pos(BINARY_TUPLE_BASE, TUPLE_GT_OFFSET)), + BinaryOperator::GtEq => { + unit_binary_ref(binary_pos(BINARY_TUPLE_BASE, TUPLE_GT_EQ_OFFSET)) + } + BinaryOperator::Lt => unit_binary_ref(binary_pos(BINARY_TUPLE_BASE, TUPLE_LT_OFFSET)), + BinaryOperator::LtEq => { + unit_binary_ref(binary_pos(BINARY_TUPLE_BASE, TUPLE_LT_EQ_OFFSET)) + } _ => Err(DatabaseError::UnsupportedBinaryOperator(ty.clone(), op)), }, } } +macro_rules! eval_numeric_binary { + ($pos:expr, $base:expr, $value_type:ident, $left:expr, $right:expr) => { + paste! { + match $pos - $base { + NUMERIC_PLUS_OFFSET => [<$value_type:snake _plus_binary_eval>]($left, $right), + NUMERIC_MINUS_OFFSET => [<$value_type:snake _minus_binary_eval>]($left, $right), + NUMERIC_MULTIPLY_OFFSET => [<$value_type:snake _multiply_binary_eval>]($left, $right), + NUMERIC_DIVIDE_OFFSET => [<$value_type:snake _divide_binary_eval>]($left, $right), + NUMERIC_GT_OFFSET => [<$value_type:snake _gt_binary_eval>]($left, $right), + NUMERIC_GT_EQ_OFFSET => [<$value_type:snake _gt_eq_binary_eval>]($left, $right), + NUMERIC_LT_OFFSET => [<$value_type:snake _lt_binary_eval>]($left, $right), + NUMERIC_LT_EQ_OFFSET => [<$value_type:snake _lt_eq_binary_eval>]($left, $right), + NUMERIC_EQ_OFFSET => [<$value_type:snake _eq_binary_eval>]($left, $right), + NUMERIC_NOT_EQ_OFFSET => [<$value_type:snake _not_eq_binary_eval>]($left, $right), + NUMERIC_MODULO_OFFSET => [<$value_type:snake _mod_binary_eval>]($left, $right), + _ => unreachable!(), + } + } + }; +} + +pub(crate) fn eval_binary( + pos: u16, + params: &BinaryEvaluatorParams, + left: &crate::types::value::DataValue, + right: &crate::types::value::DataValue, +) -> Result { + match pos { + BINARY_INT8_BASE..BINARY_INT16_BASE => { + eval_numeric_binary!(pos, BINARY_INT8_BASE, Int8, left, right) + } + BINARY_INT16_BASE..BINARY_INT32_BASE => { + eval_numeric_binary!(pos, BINARY_INT16_BASE, Int16, left, right) + } + BINARY_INT32_BASE..BINARY_INT64_BASE => { + eval_numeric_binary!(pos, BINARY_INT32_BASE, Int32, left, right) + } + BINARY_INT64_BASE..BINARY_UINT8_BASE => { + eval_numeric_binary!(pos, BINARY_INT64_BASE, Int64, left, right) + } + BINARY_UINT8_BASE..BINARY_UINT16_BASE => { + eval_numeric_binary!(pos, BINARY_UINT8_BASE, Uint8, left, right) + } + BINARY_UINT16_BASE..BINARY_UINT32_BASE => { + eval_numeric_binary!(pos, BINARY_UINT16_BASE, Uint16, left, right) + } + BINARY_UINT32_BASE..BINARY_UINT64_BASE => { + eval_numeric_binary!(pos, BINARY_UINT32_BASE, Uint32, left, right) + } + BINARY_UINT64_BASE..BINARY_FLOAT32_BASE => { + eval_numeric_binary!(pos, BINARY_UINT64_BASE, Uint64, left, right) + } + BINARY_FLOAT32_BASE..BINARY_FLOAT64_BASE => { + eval_numeric_binary!(pos, BINARY_FLOAT32_BASE, Float32, left, right) + } + BINARY_FLOAT64_BASE..BINARY_DATE_BASE => { + eval_numeric_binary!(pos, BINARY_FLOAT64_BASE, Float64, left, right) + } + #[cfg(feature = "time")] + BINARY_DATE_BASE..BINARY_DATETIME_BASE => { + eval_numeric_binary!(pos, BINARY_DATE_BASE, Date, left, right) + } + #[cfg(not(feature = "time"))] + BINARY_DATE_BASE..BINARY_DATETIME_BASE => Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )), + #[cfg(feature = "time")] + BINARY_DATETIME_BASE..BINARY_DECIMAL_BASE => { + eval_numeric_binary!(pos, BINARY_DATETIME_BASE, DateTime, left, right) + } + #[cfg(not(feature = "time"))] + BINARY_DATETIME_BASE..BINARY_DECIMAL_BASE => Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )), + #[cfg(feature = "decimal")] + BINARY_DECIMAL_BASE..BINARY_TIME_BASE => { + eval_numeric_binary!(pos, BINARY_DECIMAL_BASE, Decimal, left, right) + } + #[cfg(not(feature = "decimal"))] + BINARY_DECIMAL_BASE..BINARY_TIME_BASE => Err(DatabaseError::UnsupportedStmt( + "DECIMAL requires the `decimal` feature".to_string(), + )), + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME_BASE, TIME_PLUS_OFFSET) => { + time_plus_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME_BASE, TIME_MINUS_OFFSET) => { + time_minus_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME_BASE, TIME_GT_OFFSET) => time_gt_binary_eval(left, right), + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME_BASE, TIME_GT_EQ_OFFSET) => { + time_gt_eq_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME_BASE, TIME_LT_OFFSET) => time_lt_binary_eval(left, right), + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME_BASE, TIME_LT_EQ_OFFSET) => { + time_lt_eq_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME_BASE, TIME_EQ_OFFSET) => time_eq_binary_eval(left, right), + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME_BASE, TIME_NOT_EQ_OFFSET) => { + time_not_eq_binary_eval(left, right) + } + #[cfg(not(feature = "time"))] + BINARY_TIME_BASE..BINARY_TIME64_BASE => Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )), + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME64_BASE, TIMESTAMP_GT_OFFSET) => { + time64_gt_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME64_BASE, TIMESTAMP_GT_EQ_OFFSET) => { + time64_gt_eq_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME64_BASE, TIMESTAMP_LT_OFFSET) => { + time64_lt_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME64_BASE, TIMESTAMP_LT_EQ_OFFSET) => { + time64_lt_eq_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME64_BASE, TIMESTAMP_EQ_OFFSET) => { + time64_eq_binary_eval(left, right) + } + #[cfg(feature = "time")] + x if x == binary_pos(BINARY_TIME64_BASE, TIMESTAMP_NOT_EQ_OFFSET) => { + time64_not_eq_binary_eval(left, right) + } + #[cfg(not(feature = "time"))] + BINARY_TIME64_BASE..BINARY_BOOLEAN_BASE => Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )), + x if x == binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_AND_OFFSET) => { + boolean_and_binary_eval(left, right) + } + x if x == binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_OR_OFFSET) => { + boolean_or_binary_eval(left, right) + } + x if x == binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_EQ_OFFSET) => { + boolean_eq_binary_eval(left, right) + } + x if x == binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_NOT_EQ_OFFSET) => { + boolean_not_eq_binary_eval(left, right) + } + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_GT_OFFSET) => utf8_gt_binary_eval(left, right), + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_LT_OFFSET) => utf8_lt_binary_eval(left, right), + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_GT_EQ_OFFSET) => { + utf8_gt_eq_binary_eval(left, right) + } + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_LT_EQ_OFFSET) => { + utf8_lt_eq_binary_eval(left, right) + } + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_EQ_OFFSET) => utf8_eq_binary_eval(left, right), + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_NOT_EQ_OFFSET) => { + utf8_not_eq_binary_eval(left, right) + } + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_STRING_CONCAT_OFFSET) => { + utf8_string_concat_binary_eval(left, right) + } + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_LIKE_OFFSET) => { + let BinaryEvaluatorParams::Like { escape_char } = params else { + unreachable!() + }; + utf8_like_binary_eval(*escape_char, left, right) + } + x if x == binary_pos(BINARY_UTF8_BASE, UTF8_NOT_LIKE_OFFSET) => { + let BinaryEvaluatorParams::Like { escape_char } = params else { + unreachable!() + }; + utf8_not_like_binary_eval(*escape_char, left, right) + } + BINARY_SQL_NULL => null_binary_eval(left, right), + x if x == binary_pos(BINARY_TUPLE_BASE, TUPLE_EQ_OFFSET) => { + tuple_eq_binary_eval(left, right) + } + x if x == binary_pos(BINARY_TUPLE_BASE, TUPLE_NOT_EQ_OFFSET) => { + tuple_not_eq_binary_eval(left, right) + } + x if x == binary_pos(BINARY_TUPLE_BASE, TUPLE_GT_OFFSET) => { + tuple_gt_binary_eval(left, right) + } + x if x == binary_pos(BINARY_TUPLE_BASE, TUPLE_GT_EQ_OFFSET) => { + tuple_gt_eq_binary_eval(left, right) + } + x if x == binary_pos(BINARY_TUPLE_BASE, TUPLE_LT_OFFSET) => { + tuple_lt_binary_eval(left, right) + } + x if x == binary_pos(BINARY_TUPLE_BASE, TUPLE_LT_EQ_OFFSET) => { + tuple_lt_eq_binary_eval(left, right) + } + _ => unreachable!("unknown binary evaluator position {pos}"), + } +} + #[macro_export] macro_rules! numeric_binary_evaluator_definition { ($value_type:ident, $compute_type:path) => { paste::paste! { - #[derive(Debug)] - pub struct [<$value_type PlusBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type MinusBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type MultiplyBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type DivideBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type GtBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type GtEqBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type LtBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type LtEqBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type EqBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type NotEqBinaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type ModBinaryEvaluator>]; impl $crate::types::evaluator::BinaryEvaluator for [<$value_type PlusBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_add(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type MinusBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_sub(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type MultiplyBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_mul(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type DivideBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Float64(ordered_float::OrderedFloat(*v1 as f64 / *v2 as f64)), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type GtBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 > v2), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type GtEqBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 >= v2), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type LtBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 < v2), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type LtEqBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 <= v2), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type EqBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 == v2), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type NotEqBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 != v2), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } - } impl $crate::types::evaluator::BinaryEvaluator for [<$value_type ModBinaryEvaluator>] { - fn binary_eval( - &self, - left: &$crate::types::value::DataValue, - right: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - Ok(match (left, right) { - ($compute_type(v1), $compute_type(v2)) => $compute_type(*v1 % *v2), - ($compute_type(_), $crate::types::value::DataValue::Null) - | ($crate::types::value::DataValue::Null, $compute_type(_)) - | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - }) - } + pub fn [<$value_type:snake _plus_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_add(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _minus_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_sub(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _multiply_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $compute_type(v1.checked_mul(*v2).ok_or($crate::errors::DatabaseError::OverFlow)?), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _divide_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Float64(ordered_float::OrderedFloat(*v1 as f64 / *v2 as f64)), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _gt_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 > v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _gt_eq_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 >= v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _lt_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 < v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _lt_eq_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 <= v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _eq_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 == v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _not_eq_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $crate::types::value::DataValue::Boolean(v1 != v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) + } + pub fn [<$value_type:snake _mod_binary_eval>]( + left: &$crate::types::value::DataValue, + right: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + Ok(match (left, right) { + ($compute_type(v1), $compute_type(v2)) => $compute_type(*v1 % *v2), + ($compute_type(_), $crate::types::value::DataValue::Null) + | ($crate::types::value::DataValue::Null, $compute_type(_)) + | ($crate::types::value::DataValue::Null, $crate::types::value::DataValue::Null) => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, + }) } } }; @@ -337,34 +658,58 @@ macro_rules! numeric_binary_evaluator_definition { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use super::binary_create; + use super::*; use crate::errors::DatabaseError; use crate::expression::BinaryOperator; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::rocksdb::RocksTransaction; - use crate::types::evaluator::BinaryEvaluatorBox; + use crate::types::evaluator::BinaryEvaluatorRef; use crate::types::LogicalType; use std::borrow::Cow; use std::io::{Cursor, Seek, SeekFrom}; - fn create(ty: LogicalType, op: BinaryOperator) -> Result { + fn create(ty: LogicalType, op: BinaryOperator) -> Result { binary_create(Cow::Owned(ty), op) } + #[test] + fn test_binary_evaluator_positions_are_stable() -> Result<(), DatabaseError> { + assert_eq!( + create(LogicalType::Integer, BinaryOperator::Plus)?.pos, + BINARY_INT32_BASE + ); + assert_eq!( + create(LogicalType::Boolean, BinaryOperator::NotEq)?.pos, + binary_pos(BINARY_BOOLEAN_BASE, BOOLEAN_NOT_EQ_OFFSET) + ); + assert_eq!( + create( + LogicalType::Varchar(None, crate::types::CharLengthUnits::Characters), + BinaryOperator::StringConcat + )? + .pos, + binary_pos(BINARY_UTF8_BASE, UTF8_STRING_CONCAT_OFFSET) + ); + + Ok(()) + } + #[test] fn test_binary_evaluator_serialization() -> Result<(), DatabaseError> { let evaluator = create(LogicalType::Boolean, BinaryOperator::NotEq)?; let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); - evaluator.encode(&mut cursor, false, &mut reference_tables)?; + evaluator.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( - BinaryEvaluatorBox::decode::( + BinaryEvaluatorRef::decode::( &mut cursor, None, - &reference_tables + &reference_tables, + &mut arena, )?, evaluator ); diff --git a/src/types/evaluator/boolean.rs b/src/types/evaluator/boolean.rs index 9f54dc6f..8a0df8c8 100644 --- a/src/types/evaluator/boolean.rs +++ b/src/types/evaluator/boolean.rs @@ -15,148 +15,131 @@ use crate::errors::DatabaseError; use crate::types::evaluator::cast::{to_char, to_varchar}; use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use std::hint; - -#[derive(Debug)] -pub struct BooleanNotUnaryEvaluator; -#[derive(Debug)] -pub struct BooleanAndBinaryEvaluator; -#[derive(Debug)] -pub struct BooleanOrBinaryEvaluator; -#[derive(Debug)] -pub struct BooleanEqBinaryEvaluator; -#[derive(Debug)] -pub struct BooleanNotEqBinaryEvaluator; -impl UnaryEvaluator for BooleanNotUnaryEvaluator { - fn unary_eval(&self, value: &DataValue) -> DataValue { - match value { - DataValue::Boolean(value) => DataValue::Boolean(!value), - DataValue::Null => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - } +pub fn boolean_not_unary_eval(value: &DataValue) -> DataValue { + match value { + DataValue::Boolean(value) => DataValue::Boolean(!value), + DataValue::Null => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, } } -impl BinaryEvaluator for BooleanAndBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Boolean(v1), DataValue::Boolean(v2)) => DataValue::Boolean(*v1 && *v2), - (DataValue::Boolean(false), DataValue::Null) - | (DataValue::Null, DataValue::Boolean(false)) => DataValue::Boolean(false), - (DataValue::Null, DataValue::Null) - | (DataValue::Boolean(true), DataValue::Null) - | (DataValue::Null, DataValue::Boolean(true)) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +pub fn boolean_and_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Boolean(v1), DataValue::Boolean(v2)) => DataValue::Boolean(*v1 && *v2), + (DataValue::Boolean(false), DataValue::Null) + | (DataValue::Null, DataValue::Boolean(false)) => DataValue::Boolean(false), + (DataValue::Null, DataValue::Null) + | (DataValue::Boolean(true), DataValue::Null) + | (DataValue::Null, DataValue::Boolean(true)) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } -impl BinaryEvaluator for BooleanOrBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Boolean(v1), DataValue::Boolean(v2)) => DataValue::Boolean(*v1 || *v2), - (DataValue::Boolean(true), DataValue::Null) - | (DataValue::Null, DataValue::Boolean(true)) => DataValue::Boolean(true), - (DataValue::Null, DataValue::Null) - | (DataValue::Boolean(false), DataValue::Null) - | (DataValue::Null, DataValue::Boolean(false)) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +pub fn boolean_or_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Boolean(v1), DataValue::Boolean(v2)) => DataValue::Boolean(*v1 || *v2), + (DataValue::Boolean(true), DataValue::Null) + | (DataValue::Null, DataValue::Boolean(true)) => DataValue::Boolean(true), + (DataValue::Null, DataValue::Null) + | (DataValue::Boolean(false), DataValue::Null) + | (DataValue::Null, DataValue::Boolean(false)) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } -impl BinaryEvaluator for BooleanEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Boolean(v1), DataValue::Boolean(v2)) => DataValue::Boolean(*v1 == *v2), - (DataValue::Null, DataValue::Boolean(_)) - | (DataValue::Boolean(_), DataValue::Null) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +pub fn boolean_eq_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Boolean(v1), DataValue::Boolean(v2)) => DataValue::Boolean(*v1 == *v2), + (DataValue::Null, DataValue::Boolean(_)) + | (DataValue::Boolean(_), DataValue::Null) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } -crate::define_cast_evaluator!(BooleanToTinyintCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_tinyint_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::Int8(if *value { 1 } else { 0 })) }); -crate::define_cast_evaluator!(BooleanToUTinyintCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_utinyint_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::UInt8(if *value { 1 } else { 0 })) }); -crate::define_cast_evaluator!(BooleanToSmallintCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_smallint_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::Int16(if *value { 1 } else { 0 })) }); -crate::define_cast_evaluator!(BooleanToUSmallintCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_usmallint_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::UInt16(if *value { 1 } else { 0 })) }); -crate::define_cast_evaluator!(BooleanToIntegerCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_integer_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::Int32(if *value { 1 } else { 0 })) }); -crate::define_cast_evaluator!(BooleanToUIntegerCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_uinteger_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::UInt32(if *value { 1 } else { 0 })) }); -crate::define_cast_evaluator!(BooleanToBigintCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_bigint_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::Int64(if *value { 1 } else { 0 })) }); -crate::define_cast_evaluator!(BooleanToUBigintCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_ubigint_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::UInt64(if *value { 1 } else { 0 })) }); -crate::define_cast_evaluator!(BooleanToFloatCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_float_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::Float32(OrderedFloat(if *value { 1.0 } else { 0.0 }))) }); -crate::define_cast_evaluator!(BooleanToDoubleCastEvaluator, DataValue::Boolean(value) => { +crate::define_cast_evaluator!(boolean_to_double_cast_eval, DataValue::Boolean(value) => { Ok(DataValue::Float64(OrderedFloat(if *value { 1.0 } else { 0.0 }))) }); crate::define_cast_evaluator!( - BooleanToCharCastEvaluator { + boolean_to_char_cast_eval { len: u32, unit: CharLengthUnits }, DataValue::Boolean(value) => |this| to_char(value.to_string(), this.len, this.unit) ); crate::define_cast_evaluator!( - BooleanToVarcharCastEvaluator { + boolean_to_varchar_cast_eval { len: Option, unit: CharLengthUnits }, DataValue::Boolean(value) => |this| to_varchar(value.to_string(), this.len, this.unit) ); -impl BinaryEvaluator for BooleanNotEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Boolean(v1), DataValue::Boolean(v2)) => DataValue::Boolean(*v1 != *v2), - (DataValue::Null, DataValue::Boolean(_)) - | (DataValue::Boolean(_), DataValue::Null) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +pub fn boolean_not_eq_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Boolean(v1), DataValue::Boolean(v2)) => DataValue::Boolean(*v1 != *v2), + (DataValue::Null, DataValue::Boolean(_)) + | (DataValue::Boolean(_), DataValue::Null) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use crate::types::value::Utf8Type; #[test] fn test_boolean_binary_evaluators() { assert_eq!( - BooleanAndBinaryEvaluator - .binary_eval(&DataValue::Boolean(true), &DataValue::Boolean(true)) - .unwrap(), + boolean_and_binary_eval(&DataValue::Boolean(true), &DataValue::Boolean(true)).unwrap(), DataValue::Boolean(true) ); assert_eq!( - BooleanAndBinaryEvaluator - .binary_eval(&DataValue::Boolean(false), &DataValue::Null) - .unwrap(), + boolean_and_binary_eval(&DataValue::Boolean(false), &DataValue::Null).unwrap(), DataValue::Boolean(false) ); assert_eq!( - BooleanOrBinaryEvaluator - .binary_eval(&DataValue::Boolean(false), &DataValue::Boolean(true)) - .unwrap(), + boolean_or_binary_eval(&DataValue::Boolean(false), &DataValue::Boolean(true)).unwrap(), DataValue::Boolean(true) ); } @@ -166,52 +149,47 @@ mod test { let value = DataValue::Boolean(true); assert_eq!( - BooleanToTinyintCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - BooleanToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - BooleanToSmallintCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - BooleanToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - BooleanToIntegerCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - BooleanToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - BooleanToBigintCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - BooleanToUBigintCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - BooleanToFloatCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - BooleanToDoubleCastEvaluator.eval_cast(&value).unwrap(), + boolean_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - BooleanToCharCastEvaluator { - len: 4, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + boolean_to_char_cast_eval(4, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "true".to_string(), ty: Utf8Type::Fixed(4), @@ -219,12 +197,7 @@ mod test { } ); assert_eq!( - BooleanToVarcharCastEvaluator { - len: Some(4), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + boolean_to_varchar_cast_eval(Some(4), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "true".to_string(), ty: Utf8Type::Variable(Some(4)), @@ -232,17 +205,15 @@ mod test { } ); assert_eq!( - BooleanToDoubleCastEvaluator - .eval_cast(&DataValue::Boolean(false)) - .unwrap(), + boolean_to_double_cast_eval(&DataValue::Boolean(false)).unwrap(), DataValue::Float64(OrderedFloat(0.0)) ); assert_eq!( - BooleanToVarcharCastEvaluator { - len: None, - unit: CharLengthUnits::Characters, - } - .eval_cast(&DataValue::Boolean(false)) + boolean_to_varchar_cast_eval( + None, + CharLengthUnits::Characters, + &DataValue::Boolean(false) + ) .unwrap(), DataValue::Utf8 { value: "false".to_string(), diff --git a/src/types/evaluator/cast.rs b/src/types/evaluator/cast.rs index 1b8a8485..c4ffff95 100644 --- a/src/types/evaluator/cast.rs +++ b/src/types/evaluator/cast.rs @@ -14,8 +14,11 @@ use crate::errors::DatabaseError; use crate::types::evaluator::boolean::*; +#[cfg(feature = "time")] use crate::types::evaluator::date::*; +#[cfg(feature = "time")] use crate::types::evaluator::datetime::*; +#[cfg(feature = "decimal")] use crate::types::evaluator::decimal::*; use crate::types::evaluator::float32::*; use crate::types::evaluator::float64::*; @@ -23,22 +26,23 @@ use crate::types::evaluator::int16::*; use crate::types::evaluator::int32::*; use crate::types::evaluator::int64::*; use crate::types::evaluator::int8::*; -use crate::types::evaluator::null::{NullCastEvaluator, ToSqlNullCastEvaluator}; +use crate::types::evaluator::null::{null_cast_eval, to_sql_null_cast_eval}; +#[cfg(feature = "time")] use crate::types::evaluator::time32::*; +#[cfg(feature = "time")] use crate::types::evaluator::time64::*; -use crate::types::evaluator::tuple::TupleCastEvaluator; +use crate::types::evaluator::tuple::eval_tuple_cast; use crate::types::evaluator::uint16::*; use crate::types::evaluator::uint32::*; use crate::types::evaluator::uint64::*; use crate::types::evaluator::uint8::*; use crate::types::evaluator::utf8::*; -use crate::types::evaluator::{CastEvaluator, CastEvaluatorBox}; +use crate::types::evaluator::{CastEvaluatorParams, CastEvaluatorRef}; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; use crate::types::LogicalType; use paste::paste; use std::borrow::Cow; -use std::sync::Arc; pub(crate) fn cast_fail(from: LogicalType, to: LogicalType) -> DatabaseError { DatabaseError::CastFail { @@ -48,6 +52,79 @@ pub(crate) fn cast_fail(from: LogicalType, to: LogicalType) -> DatabaseError { } } +const CAST_TYPE_STRIDE: u16 = 32; +const CAST_SQL_NULL: u16 = 0; +const CAST_BOOLEAN: u16 = 1; +const CAST_TINYINT: u16 = 2; +const CAST_UTINYINT: u16 = 3; +const CAST_SMALLINT: u16 = 4; +const CAST_USMALLINT: u16 = 5; +const CAST_INTEGER: u16 = 6; +const CAST_UINTEGER: u16 = 7; +const CAST_BIGINT: u16 = 8; +const CAST_UBIGINT: u16 = 9; +const CAST_FLOAT: u16 = 10; +const CAST_DOUBLE: u16 = 11; +const CAST_CHAR: u16 = 12; +const CAST_VARCHAR: u16 = 13; +const CAST_DATE: u16 = 14; +const CAST_DATETIME: u16 = 15; +const CAST_TIME: u16 = 16; +const CAST_TIMESTAMP: u16 = 17; +#[cfg(feature = "decimal")] +const CAST_DECIMAL: u16 = 18; +const CAST_TUPLE: u16 = 19; + +// Cast positions are serialized ABI. Type codes above must never be reordered +// or reused; new cast families should append a new code and keep old positions. +fn cast_type_code(ty: &LogicalType) -> u16 { + match ty { + LogicalType::SqlNull => CAST_SQL_NULL, + LogicalType::Boolean => CAST_BOOLEAN, + LogicalType::Tinyint => CAST_TINYINT, + LogicalType::UTinyint => CAST_UTINYINT, + LogicalType::Smallint => CAST_SMALLINT, + LogicalType::USmallint => CAST_USMALLINT, + LogicalType::Integer => CAST_INTEGER, + LogicalType::UInteger => CAST_UINTEGER, + LogicalType::Bigint => CAST_BIGINT, + LogicalType::UBigint => CAST_UBIGINT, + LogicalType::Float => CAST_FLOAT, + LogicalType::Double => CAST_DOUBLE, + LogicalType::Char(_, _) => CAST_CHAR, + LogicalType::Varchar(_, _) => CAST_VARCHAR, + LogicalType::Date => CAST_DATE, + LogicalType::DateTime => CAST_DATETIME, + LogicalType::Time(_) => CAST_TIME, + LogicalType::TimeStamp(_, _) => CAST_TIMESTAMP, + #[cfg(feature = "decimal")] + LogicalType::Decimal(_, _) => CAST_DECIMAL, + #[cfg(not(feature = "decimal"))] + LogicalType::Decimal(_, _) => unreachable!("DECIMAL requires the `decimal` feature"), + LogicalType::Tuple(_) => CAST_TUPLE, + } +} + +fn cast_pos(from: &LogicalType, to: &LogicalType) -> u16 { + cast_type_code(from) * CAST_TYPE_STRIDE + cast_type_code(to) +} + +#[cfg(not(feature = "time"))] +fn is_chrono_type(ty: &LogicalType) -> bool { + matches!( + ty, + LogicalType::Date + | LogicalType::DateTime + | LogicalType::Time(_) + | LogicalType::TimeStamp(_, _) + ) +} + +#[cfg(not(feature = "time"))] +fn is_chrono_type_code(code: u16) -> bool { + matches!(code, CAST_DATE | CAST_DATETIME | CAST_TIME | CAST_TIMESTAMP) +} + pub(crate) fn to_char( value: String, len: u32, @@ -110,6 +187,7 @@ macro_rules! float_to_int_cast { }}; } +#[cfg(feature = "decimal")] #[macro_export] macro_rules! decimal_to_int_cast { ($decimal:expr, $int_type:ty) => {{ @@ -139,95 +217,69 @@ macro_rules! decimal_to_int_cast { #[macro_export] macro_rules! define_cast_evaluator { ($name:ident, $pattern:pat => $body:block) => { - #[derive(Debug)] - pub struct $name; - impl $crate::types::evaluator::CastEvaluator for $name { - fn eval_cast( - &self, - value: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - match value { - $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), - $pattern => $body, - _ => unsafe { std::hint::unreachable_unchecked() }, - } + pub fn $name( + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => $body, + _ => unsafe { std::hint::unreachable_unchecked() }, } } }; ($name:ident, $pattern:pat => |$this:ident| $body:expr) => { - #[derive(Debug)] - pub struct $name; - impl $crate::types::evaluator::CastEvaluator for $name { - fn eval_cast( - &self, - value: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - match value { - $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), - $pattern => { - let $this = self; - $body - } - _ => unsafe { std::hint::unreachable_unchecked() }, + pub fn $name( + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => { + let $this = (); + $body } + _ => unsafe { std::hint::unreachable_unchecked() }, } } }; ($name:ident, $pattern:pat => |$this:ident| $body:block) => { - #[derive(Debug)] - pub struct $name; - impl $crate::types::evaluator::CastEvaluator for $name { - fn eval_cast( - &self, - value: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - match value { - $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), - $pattern => { - let $this = self; - $body - } - _ => unsafe { std::hint::unreachable_unchecked() }, + pub fn $name( + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => { + let $this = (); + $body } + _ => unsafe { std::hint::unreachable_unchecked() }, } } }; ($name:ident { $($field:ident : $field_ty:ty),+ $(,)? }, $pattern:pat => |$this:ident| $body:expr) => { - #[derive(Debug)] - pub struct $name { - $(pub $field: $field_ty),+ - } - impl $crate::types::evaluator::CastEvaluator for $name { - fn eval_cast( - &self, - value: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - match value { - $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), - $pattern => { - let $this = self; - $body - } - _ => unsafe { std::hint::unreachable_unchecked() }, - } + pub fn $name( + $($field: $field_ty,)+ + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + struct This { + $($field: $field_ty),+ + } + let $this = This { $($field),+ }; + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => $body, + _ => unsafe { std::hint::unreachable_unchecked() }, } } }; ($name:ident { $($field:ident : $field_ty:ty),+ $(,)? }, $pattern:pat => $body:block) => { - #[derive(Debug)] - pub struct $name { - $(pub $field: $field_ty),+ - } - impl $crate::types::evaluator::CastEvaluator for $name { - fn eval_cast( - &self, - value: &$crate::types::value::DataValue, - ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { - match value { - $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), - $pattern => $body, - _ => unsafe { std::hint::unreachable_unchecked() }, - } + pub fn $name( + $($field: $field_ty,)+ + value: &$crate::types::value::DataValue, + ) -> Result<$crate::types::value::DataValue, $crate::errors::DatabaseError> { + match value { + $crate::types::value::DataValue::Null => Ok($crate::types::value::DataValue::Null), + $pattern => $body, + _ => unsafe { std::hint::unreachable_unchecked() }, } } }; @@ -237,59 +289,60 @@ macro_rules! define_cast_evaluator { macro_rules! define_integer_cast_evaluators { ($prefix:ident, $variant:ident, $src_ty:ty, $from_ty:expr) => { paste::paste! { - $crate::define_cast_evaluator!([<$prefix ToBooleanCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_boolean_cast_eval>], $crate::types::value::DataValue::$variant(value) => { $crate::numeric_to_boolean_cast!(*value, $from_ty) }); - $crate::define_cast_evaluator!([<$prefix ToTinyintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_tinyint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Int8(i8::try_from(*value)?)) }); - $crate::define_cast_evaluator!([<$prefix ToUTinyintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_utinyint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::UInt8(u8::try_from(*value)?)) }); - $crate::define_cast_evaluator!([<$prefix ToSmallintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_smallint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Int16(i16::try_from(*value)?)) }); - $crate::define_cast_evaluator!([<$prefix ToUSmallintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_usmallint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::UInt16(u16::try_from(*value)?)) }); - $crate::define_cast_evaluator!([<$prefix ToIntegerCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_integer_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Int32(i32::try_from(*value)?)) }); - $crate::define_cast_evaluator!([<$prefix ToUIntegerCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_uinteger_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::UInt32(u32::try_from(*value)?)) }); - $crate::define_cast_evaluator!([<$prefix ToBigintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_bigint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Int64(i64::try_from(*value)?)) }); - $crate::define_cast_evaluator!([<$prefix ToUBigintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_ubigint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::UInt64(u64::try_from(*value)?)) }); - $crate::define_cast_evaluator!([<$prefix ToFloatCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_float_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Float32(ordered_float::OrderedFloat(*value as f32))) }); - $crate::define_cast_evaluator!([<$prefix ToDoubleCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_double_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Float64(ordered_float::OrderedFloat(*value as f64))) }); $crate::define_cast_evaluator!( - [<$prefix ToCharCastEvaluator>] { + [<$prefix:snake _to_char_cast_eval>] { len: u32, - unit: crate::types::CharLengthUnits + unit: $crate::types::CharLengthUnits }, $crate::types::value::DataValue::$variant(value) => |this| { $crate::types::evaluator::cast::to_char(value.to_string(), this.len, this.unit) } ); $crate::define_cast_evaluator!( - [<$prefix ToVarcharCastEvaluator>] { + [<$prefix:snake _to_varchar_cast_eval>] { len: Option, - unit: crate::types::CharLengthUnits + unit: $crate::types::CharLengthUnits }, $crate::types::value::DataValue::$variant(value) => |this| { $crate::types::evaluator::cast::to_varchar(value.to_string(), this.len, this.unit) } ); + #[cfg(feature = "decimal")] $crate::define_cast_evaluator!( - [<$prefix ToDecimalCastEvaluator>] { + [<$prefix:snake _to_decimal_cast_eval>] { scale: Option }, $crate::types::value::DataValue::$variant(value) => |this| { @@ -306,62 +359,66 @@ macro_rules! define_integer_cast_evaluators { macro_rules! define_float_cast_evaluators { ($prefix:ident, $variant:ident, $src_ty:ty, $from_ty:expr, $into_decimal:ident) => { paste::paste! { - $crate::define_cast_evaluator!([<$prefix ToFloatCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_float_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::$variant(*value)) }); - $crate::define_cast_evaluator!([<$prefix ToDoubleCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_double_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Float64(ordered_float::OrderedFloat(value.0 as f64))) }); - $crate::define_cast_evaluator!([<$prefix ToTinyintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_tinyint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Int8($crate::float_to_int_cast!(value.into_inner(), i8, $src_ty)?)) }); - $crate::define_cast_evaluator!([<$prefix ToSmallintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_smallint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Int16($crate::float_to_int_cast!(value.into_inner(), i16, $src_ty)?)) }); - $crate::define_cast_evaluator!([<$prefix ToIntegerCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_integer_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Int32($crate::float_to_int_cast!(value.into_inner(), i32, $src_ty)?)) }); - $crate::define_cast_evaluator!([<$prefix ToBigintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_bigint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::Int64($crate::float_to_int_cast!(value.into_inner(), i64, $src_ty)?)) }); - $crate::define_cast_evaluator!([<$prefix ToUTinyintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_utinyint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::UInt8($crate::float_to_int_cast!(value.into_inner(), u8, $src_ty)?)) }); - $crate::define_cast_evaluator!([<$prefix ToUSmallintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_usmallint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::UInt16($crate::float_to_int_cast!(value.into_inner(), u16, $src_ty)?)) }); - $crate::define_cast_evaluator!([<$prefix ToUIntegerCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_uinteger_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::UInt32($crate::float_to_int_cast!(value.into_inner(), u32, $src_ty)?)) }); - $crate::define_cast_evaluator!([<$prefix ToUBigintCastEvaluator>], $crate::types::value::DataValue::$variant(value) => { + $crate::define_cast_evaluator!([<$prefix:snake _to_ubigint_cast_eval>], $crate::types::value::DataValue::$variant(value) => { Ok($crate::types::value::DataValue::UInt64($crate::float_to_int_cast!(value.into_inner(), u64, $src_ty)?)) }); $crate::define_cast_evaluator!( - [<$prefix ToCharCastEvaluator>] { + [<$prefix:snake _to_char_cast_eval>] { len: u32, - unit: crate::types::CharLengthUnits + unit: $crate::types::CharLengthUnits }, $crate::types::value::DataValue::$variant(value) => |this| { $crate::types::evaluator::cast::to_char(value.to_string(), this.len, this.unit) } ); $crate::define_cast_evaluator!( - [<$prefix ToVarcharCastEvaluator>] { + [<$prefix:snake _to_varchar_cast_eval>] { len: Option, - unit: crate::types::CharLengthUnits + unit: $crate::types::CharLengthUnits }, $crate::types::value::DataValue::$variant(value) => |this| { $crate::types::evaluator::cast::to_varchar(value.to_string(), this.len, this.unit) } ); + #[cfg(feature = "decimal")] $crate::define_cast_evaluator!( - [<$prefix ToDecimalCastEvaluator>] { + [<$prefix:snake _to_decimal_cast_eval>] { + precision: Option, scale: Option, - to: $crate::types::LogicalType }, $crate::types::value::DataValue::$variant(value) => |this| { let mut decimal = rust_decimal::Decimal::$into_decimal(value.0).ok_or_else(|| { - $crate::types::evaluator::cast::cast_fail($from_ty, this.to.clone()) + $crate::types::evaluator::cast::cast_fail( + $from_ty, + $crate::types::LogicalType::Decimal(this.precision, this.scale), + ) })?; $crate::types::value::DataValue::decimal_round_f(&this.scale, &mut decimal); Ok($crate::types::value::DataValue::Decimal(decimal)) @@ -370,44 +427,102 @@ macro_rules! define_float_cast_evaluators { } }; } +pub fn identity_cast_eval(value: &DataValue) -> Result { + Ok(value.clone()) +} -#[derive(Debug)] -pub struct IdentityCastEvaluator; -impl CastEvaluator for IdentityCastEvaluator { - fn eval_cast(&self, value: &DataValue) -> Result { - Ok(value.clone()) - } +macro_rules! cast_ref { + ($from:expr, $to:expr) => {{ + Ok(CastEvaluatorRef::new( + cast_pos($from, $to), + CastEvaluatorParams::Unit, + )) + }}; } -macro_rules! box_cast { - ($from:expr, $to:expr, $evaluator:expr) => { - Ok(CastEvaluatorBox::new( - Arc::new($evaluator), - $from.clone(), - $to.clone(), +macro_rules! cast_string_ref { + ($from:expr, $to:expr, $len:expr, $unit:expr) => {{ + Ok(CastEvaluatorRef::new( + cast_pos($from, $to), + CastEvaluatorParams::String { + len: $len, + unit: $unit, + }, )) - }; + }}; +} + +#[cfg(feature = "decimal")] +macro_rules! cast_decimal_ref { + ($from:expr, $to:expr, $precision:expr, $scale:expr) => {{ + Ok(CastEvaluatorRef::new( + cast_pos($from, $to), + CastEvaluatorParams::Decimal { + precision: $precision, + scale: $scale, + }, + )) + }}; +} + +macro_rules! cast_precision_ref { + ($from:expr, $to:expr, $precision:expr) => {{ + Ok(CastEvaluatorRef::new( + cast_pos($from, $to), + CastEvaluatorParams::Precision { + precision: $precision, + }, + )) + }}; +} + +macro_rules! cast_timestamp_ref { + ($from:expr, $to:expr, $precision:expr, $zone:expr) => {{ + Ok(CastEvaluatorRef::new( + cast_pos($from, $to), + CastEvaluatorParams::Timestamp { + precision: $precision, + zone: $zone, + }, + )) + }}; } macro_rules! build_integer_cast { - ($prefix:ident, $to:expr, $from:expr) => {{ + ($cast:ident, $prefix:ident, $to:expr, $from:expr) => {{ paste! { match $to { - LogicalType::SqlNull => box_cast!($from, $to, ToSqlNullCastEvaluator), - LogicalType::Boolean => box_cast!($from, $to, [<$prefix ToBooleanCastEvaluator>]), - LogicalType::Tinyint => box_cast!($from, $to, [<$prefix ToTinyintCastEvaluator>]), - LogicalType::UTinyint => box_cast!($from, $to, [<$prefix ToUTinyintCastEvaluator>]), - LogicalType::Smallint => box_cast!($from, $to, [<$prefix ToSmallintCastEvaluator>]), - LogicalType::USmallint => box_cast!($from, $to, [<$prefix ToUSmallintCastEvaluator>]), - LogicalType::Integer => box_cast!($from, $to, [<$prefix ToIntegerCastEvaluator>]), - LogicalType::UInteger => box_cast!($from, $to, [<$prefix ToUIntegerCastEvaluator>]), - LogicalType::Bigint => box_cast!($from, $to, [<$prefix ToBigintCastEvaluator>]), - LogicalType::UBigint => box_cast!($from, $to, [<$prefix ToUBigintCastEvaluator>]), - LogicalType::Float => box_cast!($from, $to, [<$prefix ToFloatCastEvaluator>]), - LogicalType::Double => box_cast!($from, $to, [<$prefix ToDoubleCastEvaluator>]), - LogicalType::Char(len, unit) => box_cast!($from, $to, [<$prefix ToCharCastEvaluator>] { len: *len, unit: *unit }), - LogicalType::Varchar(len, unit) => box_cast!($from, $to, [<$prefix ToVarcharCastEvaluator>] { len: *len, unit: *unit }), - LogicalType::Decimal(_, scale) => box_cast!($from, $to, [<$prefix ToDecimalCastEvaluator>] { scale: *scale }), + LogicalType::SqlNull => $cast!($from, $to), + LogicalType::Boolean => $cast!($from, $to), + LogicalType::Tinyint => $cast!($from, $to), + LogicalType::UTinyint => $cast!($from, $to), + LogicalType::Smallint => $cast!($from, $to), + LogicalType::USmallint => $cast!($from, $to), + LogicalType::Integer => $cast!($from, $to), + LogicalType::UInteger => $cast!($from, $to), + LogicalType::Bigint => $cast!($from, $to), + LogicalType::UBigint => $cast!($from, $to), + LogicalType::Float => $cast!($from, $to), + LogicalType::Double => $cast!($from, $to), + LogicalType::Char(len, unit) => cast_string_ref!( + $from, + $to, + Some(*len), + *unit + ), + LogicalType::Varchar(len, unit) => cast_string_ref!( + $from, + $to, + *len, + *unit + ), + #[cfg(feature = "decimal")] + LogicalType::Decimal(precision, scale) => cast_decimal_ref!( + $from, + $to, + *precision, + *scale + ), _ => Err(cast_fail($from.clone(), $to.clone())), } } @@ -417,534 +532,678 @@ macro_rules! build_integer_cast { pub fn cast_create( from: Cow<'_, LogicalType>, to: Cow<'_, LogicalType>, -) -> Result { +) -> Result { let from = from.as_ref(); let to = to.as_ref(); if from == to { - return box_cast!(from, to, IdentityCastEvaluator); + return Ok(CastEvaluatorRef::new( + cast_pos(from, to), + CastEvaluatorParams::Identity, + )); + } + #[cfg(not(feature = "time"))] + if is_chrono_type(from) || is_chrono_type(to) { + return Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )); } match (from, to) { - (LogicalType::SqlNull, _) => box_cast!(from, to, NullCastEvaluator), - (_, LogicalType::SqlNull) => box_cast!(from, to, ToSqlNullCastEvaluator), + (LogicalType::SqlNull, _) => cast_ref!(from, to), + (_, LogicalType::SqlNull) => cast_ref!(from, to), (LogicalType::Boolean, LogicalType::Tinyint) => { - box_cast!(from, to, BooleanToTinyintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::UTinyint) => { - box_cast!(from, to, BooleanToUTinyintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::Smallint) => { - box_cast!(from, to, BooleanToSmallintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::USmallint) => { - box_cast!(from, to, BooleanToUSmallintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::Integer) => { - box_cast!(from, to, BooleanToIntegerCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::UInteger) => { - box_cast!(from, to, BooleanToUIntegerCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::Bigint) => { - box_cast!(from, to, BooleanToBigintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::UBigint) => { - box_cast!(from, to, BooleanToUBigintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::Float) => { - box_cast!(from, to, BooleanToFloatCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::Double) => { - box_cast!(from, to, BooleanToDoubleCastEvaluator) + cast_ref!(from, to) } (LogicalType::Boolean, LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - BooleanToCharCastEvaluator { - len: *len, - unit: *unit - } - ) + cast_string_ref!(from, to, Some(*len), *unit) } (LogicalType::Boolean, LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - BooleanToVarcharCastEvaluator { - len: *len, - unit: *unit - } - ) - } - (LogicalType::Tinyint, _) => build_integer_cast!(Int8, to, from), - (LogicalType::Smallint, _) => build_integer_cast!(Int16, to, from), - (LogicalType::Integer, _) => build_integer_cast!(Int32, to, from), - (LogicalType::Bigint, _) => build_integer_cast!(Int64, to, from), - (LogicalType::UTinyint, _) => build_integer_cast!(UInt8, to, from), - (LogicalType::USmallint, _) => build_integer_cast!(UInt16, to, from), - (LogicalType::UInteger, _) => build_integer_cast!(UInt32, to, from), - (LogicalType::UBigint, _) => build_integer_cast!(UInt64, to, from), + cast_string_ref!(from, to, *len, *unit) + } + (LogicalType::Tinyint, _) => build_integer_cast!(cast_ref, Int8, to, from), + (LogicalType::Smallint, _) => build_integer_cast!(cast_ref, Int16, to, from), + (LogicalType::Integer, _) => build_integer_cast!(cast_ref, Int32, to, from), + (LogicalType::Bigint, _) => build_integer_cast!(cast_ref, Int64, to, from), + (LogicalType::UTinyint, _) => build_integer_cast!(cast_ref, UInt8, to, from), + (LogicalType::USmallint, _) => build_integer_cast!(cast_ref, UInt16, to, from), + (LogicalType::UInteger, _) => build_integer_cast!(cast_ref, UInt32, to, from), + (LogicalType::UBigint, _) => build_integer_cast!(cast_ref, UInt64, to, from), (LogicalType::Float, LogicalType::Tinyint) => { - box_cast!(from, to, Float32ToTinyintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::UTinyint) => { - box_cast!(from, to, Float32ToUTinyintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::Smallint) => { - box_cast!(from, to, Float32ToSmallintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::USmallint) => { - box_cast!(from, to, Float32ToUSmallintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::Integer) => { - box_cast!(from, to, Float32ToIntegerCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::UInteger) => { - box_cast!(from, to, Float32ToUIntegerCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::Bigint) => { - box_cast!(from, to, Float32ToBigintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::UBigint) => { - box_cast!(from, to, Float32ToUBigintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::Double) => { - box_cast!(from, to, Float32ToDoubleCastEvaluator) + cast_ref!(from, to) } (LogicalType::Float, LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - Float32ToCharCastEvaluator { - len: *len, - unit: *unit - } - ) + cast_string_ref!(from, to, Some(*len), *unit) } (LogicalType::Float, LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - Float32ToVarcharCastEvaluator { - len: *len, - unit: *unit - } - ) - } - (LogicalType::Float, LogicalType::Decimal(_, scale)) => { - box_cast!( - from, - to, - Float32ToDecimalCastEvaluator { - scale: *scale, - to: to.clone() - } - ) + cast_string_ref!(from, to, *len, *unit) + } + #[cfg(feature = "decimal")] + (LogicalType::Float, LogicalType::Decimal(precision, scale)) => { + cast_decimal_ref!(from, to, *precision, *scale) } (LogicalType::Double, LogicalType::Float) => { - box_cast!(from, to, Float64ToFloatCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::Tinyint) => { - box_cast!(from, to, Float64ToTinyintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::UTinyint) => { - box_cast!(from, to, Float64ToUTinyintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::Smallint) => { - box_cast!(from, to, Float64ToSmallintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::USmallint) => { - box_cast!(from, to, Float64ToUSmallintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::Integer) => { - box_cast!(from, to, Float64ToIntegerCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::UInteger) => { - box_cast!(from, to, Float64ToUIntegerCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::Bigint) => { - box_cast!(from, to, Float64ToBigintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::UBigint) => { - box_cast!(from, to, Float64ToUBigintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Double, LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - Float64ToCharCastEvaluator { - len: *len, - unit: *unit - } - ) + cast_string_ref!(from, to, Some(*len), *unit) } (LogicalType::Double, LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - Float64ToVarcharCastEvaluator { - len: *len, - unit: *unit - } - ) - } - (LogicalType::Double, LogicalType::Decimal(_, scale)) => { - box_cast!( - from, - to, - Float64ToDecimalCastEvaluator { - scale: *scale, - to: to.clone() - } - ) + cast_string_ref!(from, to, *len, *unit) + } + #[cfg(feature = "decimal")] + (LogicalType::Double, LogicalType::Decimal(precision, scale)) => { + cast_decimal_ref!(from, to, *precision, *scale) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Boolean) => { - box_cast!(from, to, Utf8ToBooleanCastEvaluator { from: from.clone() }) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Tinyint) => { - box_cast!(from, to, Utf8ToTinyintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UTinyint) => { - box_cast!(from, to, Utf8ToUTinyintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Smallint) => { - box_cast!(from, to, Utf8ToSmallintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::USmallint) => { - box_cast!(from, to, Utf8ToUSmallintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Integer) => { - box_cast!(from, to, Utf8ToIntegerCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UInteger) => { - box_cast!(from, to, Utf8ToUIntegerCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Bigint) => { - box_cast!(from, to, Utf8ToBigintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::UBigint) => { - box_cast!(from, to, Utf8ToUBigintCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Float) => { - box_cast!(from, to, Utf8ToFloatCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Double) => { - box_cast!(from, to, Utf8ToDoubleCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - Utf8ToCharCastEvaluator { - len: *len, - unit: *unit - } - ) + cast_string_ref!(from, to, Some(*len), *unit) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - Utf8ToVarcharCastEvaluator { - len: *len, - unit: *unit - } - ) + cast_string_ref!(from, to, *len, *unit) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Date) => { - box_cast!(from, to, Utf8ToDateCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::DateTime) => { - box_cast!(from, to, Utf8ToDatetimeCastEvaluator) + cast_ref!(from, to) } (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Time(precision)) => { - box_cast!( - from, - to, - Utf8ToTimeCastEvaluator { - precision: *precision - } - ) + cast_precision_ref!(from, to, *precision) } ( LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::TimeStamp(precision, zone), ) => { - box_cast!( - from, - to, - Utf8ToTimestampCastEvaluator { - precision: *precision, - zone: *zone, - to: to.clone() - } - ) + cast_timestamp_ref!(from, to, *precision, *zone) } + #[cfg(feature = "decimal")] (LogicalType::Char(_, _) | LogicalType::Varchar(_, _), LogicalType::Decimal(_, _)) => { - box_cast!(from, to, Utf8ToDecimalCastEvaluator) - } - (LogicalType::Date, LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - Date32ToCharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - } - ) - } - (LogicalType::Date, LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - Date32ToVarcharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - } - ) + cast_ref!(from, to) + } + (LogicalType::Date, LogicalType::Char(_, _)) => { + cast_string_ref!(from, to, Some(char_params(to).0), char_params(to).1) + } + (LogicalType::Date, LogicalType::Varchar(_, _)) => { + let (len, unit) = varchar_params(to); + cast_string_ref!(from, to, len, unit) } (LogicalType::Date, LogicalType::DateTime) => { - box_cast!(from, to, Date32ToDatetimeCastEvaluator { to: to.clone() }) - } - (LogicalType::DateTime, LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - Date64ToCharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - } - ) - } - (LogicalType::DateTime, LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - Date64ToVarcharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - } - ) + cast_ref!(from, to) + } + (LogicalType::DateTime, LogicalType::Char(_, _)) => { + cast_string_ref!(from, to, Some(char_params(to).0), char_params(to).1) + } + (LogicalType::DateTime, LogicalType::Varchar(_, _)) => { + let (len, unit) = varchar_params(to); + cast_string_ref!(from, to, len, unit) } (LogicalType::DateTime, LogicalType::Date) => { - box_cast!(from, to, Date64ToDateCastEvaluator { to: to.clone() }) + cast_ref!(from, to) } (LogicalType::DateTime, LogicalType::Time(precision)) => { - box_cast!( - from, - to, - Date64ToTimeCastEvaluator { - precision: *precision, - to: to.clone() - } - ) + cast_precision_ref!(from, to, *precision) } (LogicalType::DateTime, LogicalType::TimeStamp(precision, zone)) => { - box_cast!( - from, - to, - Date64ToTimestampCastEvaluator { - precision: *precision, - zone: *zone - } - ) - } - (LogicalType::Time(_), LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - Time32ToCharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - } - ) - } - (LogicalType::Time(_), LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - Time32ToVarcharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - } - ) + cast_timestamp_ref!(from, to, *precision, *zone) + } + (LogicalType::Time(_), LogicalType::Char(_, _)) => { + cast_string_ref!(from, to, Some(char_params(to).0), char_params(to).1) + } + (LogicalType::Time(_), LogicalType::Varchar(_, _)) => { + let (len, unit) = varchar_params(to); + cast_string_ref!(from, to, len, unit) } (LogicalType::Time(_), LogicalType::Time(precision)) => { - box_cast!( - from, - to, - Time32ToTimeCastEvaluator { - precision: *precision - } - ) - } - (LogicalType::TimeStamp(_, _), LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - Time64ToCharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - } - ) - } - (LogicalType::TimeStamp(_, _), LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - Time64ToVarcharCastEvaluator { - len: *len, - unit: *unit, - to: to.clone() - } - ) + cast_precision_ref!(from, to, *precision) + } + (LogicalType::TimeStamp(_, _), LogicalType::Char(_, _)) => { + cast_string_ref!(from, to, Some(char_params(to).0), char_params(to).1) + } + (LogicalType::TimeStamp(_, _), LogicalType::Varchar(_, _)) => { + let (len, unit) = varchar_params(to); + cast_string_ref!(from, to, len, unit) } (LogicalType::TimeStamp(_, _), LogicalType::Date) => { - box_cast!( - from, - to, - Time64ToDateCastEvaluator { - from: from.clone(), - to: to.clone() - } - ) + cast_ref!(from, to) } (LogicalType::TimeStamp(_, _), LogicalType::DateTime) => { - box_cast!( - from, - to, - Time64ToDatetimeCastEvaluator { - from: from.clone(), - to: to.clone() - } - ) + cast_ref!(from, to) } (LogicalType::TimeStamp(_, _), LogicalType::Time(precision)) => { - box_cast!( - from, - to, - Time64ToTimeCastEvaluator { - precision: *precision, - from: from.clone(), - to: to.clone() - } - ) + cast_precision_ref!(from, to, *precision) } (LogicalType::TimeStamp(_, _), LogicalType::TimeStamp(precision, zone)) => { - box_cast!( - from, - to, - Time64ToTimestampCastEvaluator { - precision: *precision, - zone: *zone - } - ) - } - (LogicalType::Decimal(_, _), LogicalType::Float) => { - box_cast!(from, to, DecimalToFloatCastEvaluator) - } - (LogicalType::Decimal(_, _), LogicalType::Double) => { - box_cast!(from, to, DecimalToDoubleCastEvaluator) - } - (LogicalType::Decimal(_, _), LogicalType::Decimal(_, _)) => { - box_cast!(from, to, DecimalToDecimalCastEvaluator) + cast_timestamp_ref!(from, to, *precision, *zone) + } + #[cfg(feature = "decimal")] + (LogicalType::Decimal(_, _), to) => match to { + LogicalType::Float + | LogicalType::Double + | LogicalType::Decimal(_, _) + | LogicalType::Tinyint + | LogicalType::Smallint + | LogicalType::Integer + | LogicalType::Bigint + | LogicalType::UTinyint + | LogicalType::USmallint + | LogicalType::UInteger + | LogicalType::UBigint => cast_ref!(from, to), + LogicalType::Char(len, unit) => cast_string_ref!(from, to, Some(*len), *unit), + LogicalType::Varchar(len, unit) => cast_string_ref!(from, to, *len, *unit), + _ => Err(cast_fail(from.clone(), to.clone())), + }, + (LogicalType::Tuple(from_types), LogicalType::Tuple(to_types)) => { + let evaluators = from_types + .iter() + .zip(to_types.iter()) + .map(|(from, to)| cast_create(Cow::Borrowed(from), Cow::Borrowed(to))) + .collect::, _>>()?; + Ok(CastEvaluatorRef::new( + cast_pos(from, to), + CastEvaluatorParams::Tuple { evaluators }, + )) } - (LogicalType::Decimal(_, _), LogicalType::Char(len, unit)) => { - box_cast!( - from, - to, - DecimalToCharCastEvaluator { - len: *len, - unit: *unit + _ => Err(cast_fail(from.clone(), to.clone())), + } +} + +fn char_params(ty: &LogicalType) -> (u32, CharLengthUnits) { + let LogicalType::Char(len, unit) = ty else { + unreachable!("cast target must be char") + }; + (*len, *unit) +} + +fn varchar_params(ty: &LogicalType) -> (Option, CharLengthUnits) { + let LogicalType::Varchar(len, unit) = ty else { + unreachable!("cast target must be varchar") + }; + (*len, *unit) +} + +fn string_param(params: &CastEvaluatorParams) -> (Option, CharLengthUnits) { + let CastEvaluatorParams::String { len, unit } = params else { + unreachable!("cast evaluator must have string parameters") + }; + (*len, *unit) +} + +#[cfg(feature = "decimal")] +fn decimal_param(params: &CastEvaluatorParams) -> (Option, Option) { + let CastEvaluatorParams::Decimal { precision, scale } = params else { + unreachable!("cast evaluator must have decimal parameters") + }; + (*precision, *scale) +} + +#[cfg(feature = "time")] +fn precision_param(params: &CastEvaluatorParams) -> Option { + let CastEvaluatorParams::Precision { precision } = params else { + unreachable!("cast evaluator must have precision parameter") + }; + *precision +} + +#[cfg(feature = "time")] +fn timestamp_param(params: &CastEvaluatorParams) -> (Option, bool) { + let CastEvaluatorParams::Timestamp { precision, zone } = params else { + unreachable!("cast evaluator must have timestamp parameters") + }; + (*precision, *zone) +} + +macro_rules! eval_integer_cast_by_pos { + ($prefix:ident, $to_code:expr, $params:expr, $value:expr) => {{ + let params = $params; + paste! { + match $to_code { + CAST_SQL_NULL => to_sql_null_cast_eval($value), + CAST_BOOLEAN => [<$prefix:snake _to_boolean_cast_eval>]($value), + CAST_TINYINT => [<$prefix:snake _to_tinyint_cast_eval>]($value), + CAST_UTINYINT => [<$prefix:snake _to_utinyint_cast_eval>]($value), + CAST_SMALLINT => [<$prefix:snake _to_smallint_cast_eval>]($value), + CAST_USMALLINT => [<$prefix:snake _to_usmallint_cast_eval>]($value), + CAST_INTEGER => [<$prefix:snake _to_integer_cast_eval>]($value), + CAST_UINTEGER => [<$prefix:snake _to_uinteger_cast_eval>]($value), + CAST_BIGINT => [<$prefix:snake _to_bigint_cast_eval>]($value), + CAST_UBIGINT => [<$prefix:snake _to_ubigint_cast_eval>]($value), + CAST_FLOAT => [<$prefix:snake _to_float_cast_eval>]($value), + CAST_DOUBLE => [<$prefix:snake _to_double_cast_eval>]($value), + CAST_CHAR => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + [<$prefix:snake _to_char_cast_eval>](len, unit, $value) } - ) - } - (LogicalType::Decimal(_, _), LogicalType::Varchar(len, unit)) => { - box_cast!( - from, - to, - DecimalToVarcharCastEvaluator { - len: *len, - unit: *unit + CAST_VARCHAR => { + let (len, unit) = string_param(params); + [<$prefix:snake _to_varchar_cast_eval>](len, unit, $value) } - ) - } - (LogicalType::Decimal(_, _), LogicalType::Tinyint) => { - box_cast!(from, to, DecimalToTinyintCastEvaluator) - } - (LogicalType::Decimal(_, _), LogicalType::Smallint) => { - box_cast!(from, to, DecimalToSmallintCastEvaluator) - } - (LogicalType::Decimal(_, _), LogicalType::Integer) => { - box_cast!(from, to, DecimalToIntegerCastEvaluator) - } - (LogicalType::Decimal(_, _), LogicalType::Bigint) => { - box_cast!(from, to, DecimalToBigintCastEvaluator) - } - (LogicalType::Decimal(_, _), LogicalType::UTinyint) => { - box_cast!(from, to, DecimalToUTinyintCastEvaluator) + #[cfg(feature = "decimal")] + CAST_DECIMAL => { + let (_, scale) = decimal_param(params); + [<$prefix:snake _to_decimal_cast_eval>](scale, $value) + } + _ => unreachable!("invalid integer cast evaluator position"), + } } - (LogicalType::Decimal(_, _), LogicalType::USmallint) => { - box_cast!(from, to, DecimalToUSmallintCastEvaluator) + }}; +} + +impl CastEvaluatorRef { + pub fn eval(&self, value: &DataValue) -> Result { + let params = &self.params; + let from_code = self.pos / CAST_TYPE_STRIDE; + let to_code = self.pos % CAST_TYPE_STRIDE; + + macro_rules! run { + ($evaluator:ident) => { + $evaluator(value) + }; + ($evaluator:ident { $($field:ident),+ $(,)? }) => { + $evaluator($($field,)+ value) + }; + ($evaluator:ident { $($field:ident : $field_value:expr),+ $(,)? }) => { + $evaluator($($field_value,)+ value) + }; } - (LogicalType::Decimal(_, _), LogicalType::UInteger) => { - box_cast!(from, to, DecimalToUIntegerCastEvaluator) + + if matches!(params, CastEvaluatorParams::Identity) { + return run!(identity_cast_eval); } - (LogicalType::Decimal(_, _), LogicalType::UBigint) => { - box_cast!(from, to, DecimalToUBigintCastEvaluator) + #[cfg(not(feature = "time"))] + if is_chrono_type_code(from_code) || is_chrono_type_code(to_code) { + return Err(DatabaseError::UnsupportedStmt( + "time types require the `time` feature".to_string(), + )); } - (LogicalType::Tuple(from_types), LogicalType::Tuple(to_types)) => { - let evaluators = from_types - .iter() - .zip(to_types.iter()) - .map(|(from, to)| cast_create(Cow::Borrowed(from), Cow::Borrowed(to))) - .collect::, _>>()?; - box_cast!( - from, - to, - TupleCastEvaluator { - element_evaluators: evaluators + + match (from_code, to_code) { + (CAST_SQL_NULL, _) => run!(null_cast_eval), + (_, CAST_SQL_NULL) => run!(to_sql_null_cast_eval), + (CAST_BOOLEAN, CAST_TINYINT) => run!(boolean_to_tinyint_cast_eval), + (CAST_BOOLEAN, CAST_UTINYINT) => run!(boolean_to_utinyint_cast_eval), + (CAST_BOOLEAN, CAST_SMALLINT) => run!(boolean_to_smallint_cast_eval), + (CAST_BOOLEAN, CAST_USMALLINT) => run!(boolean_to_usmallint_cast_eval), + (CAST_BOOLEAN, CAST_INTEGER) => run!(boolean_to_integer_cast_eval), + (CAST_BOOLEAN, CAST_UINTEGER) => run!(boolean_to_uinteger_cast_eval), + (CAST_BOOLEAN, CAST_BIGINT) => run!(boolean_to_bigint_cast_eval), + (CAST_BOOLEAN, CAST_UBIGINT) => run!(boolean_to_ubigint_cast_eval), + (CAST_BOOLEAN, CAST_FLOAT) => run!(boolean_to_float_cast_eval), + (CAST_BOOLEAN, CAST_DOUBLE) => run!(boolean_to_double_cast_eval), + (CAST_BOOLEAN, CAST_CHAR) => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(boolean_to_char_cast_eval { len, unit }) + } + (CAST_BOOLEAN, CAST_VARCHAR) => { + let (len, unit) = string_param(params); + run!(boolean_to_varchar_cast_eval { len, unit }) + } + (CAST_TINYINT, _) => eval_integer_cast_by_pos!(Int8, to_code, params, value), + (CAST_SMALLINT, _) => eval_integer_cast_by_pos!(Int16, to_code, params, value), + (CAST_INTEGER, _) => eval_integer_cast_by_pos!(Int32, to_code, params, value), + (CAST_BIGINT, _) => eval_integer_cast_by_pos!(Int64, to_code, params, value), + (CAST_UTINYINT, _) => eval_integer_cast_by_pos!(Uint8, to_code, params, value), + (CAST_USMALLINT, _) => eval_integer_cast_by_pos!(Uint16, to_code, params, value), + (CAST_UINTEGER, _) => eval_integer_cast_by_pos!(Uint32, to_code, params, value), + (CAST_UBIGINT, _) => eval_integer_cast_by_pos!(Uint64, to_code, params, value), + (CAST_FLOAT, CAST_TINYINT) => run!(float32_to_tinyint_cast_eval), + (CAST_FLOAT, CAST_UTINYINT) => run!(float32_to_utinyint_cast_eval), + (CAST_FLOAT, CAST_SMALLINT) => run!(float32_to_smallint_cast_eval), + (CAST_FLOAT, CAST_USMALLINT) => run!(float32_to_usmallint_cast_eval), + (CAST_FLOAT, CAST_INTEGER) => run!(float32_to_integer_cast_eval), + (CAST_FLOAT, CAST_UINTEGER) => run!(float32_to_uinteger_cast_eval), + (CAST_FLOAT, CAST_BIGINT) => run!(float32_to_bigint_cast_eval), + (CAST_FLOAT, CAST_UBIGINT) => run!(float32_to_ubigint_cast_eval), + (CAST_FLOAT, CAST_DOUBLE) => run!(float32_to_double_cast_eval), + (CAST_FLOAT, CAST_CHAR) => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(float32_to_char_cast_eval { len, unit }) + } + (CAST_FLOAT, CAST_VARCHAR) => { + let (len, unit) = string_param(params); + run!(float32_to_varchar_cast_eval { len, unit }) + } + #[cfg(feature = "decimal")] + (CAST_FLOAT, CAST_DECIMAL) => { + let (precision, scale) = decimal_param(params); + run!(float32_to_decimal_cast_eval { precision, scale }) + } + (CAST_DOUBLE, CAST_FLOAT) => run!(float64_to_float_cast_eval), + (CAST_DOUBLE, CAST_TINYINT) => run!(float64_to_tinyint_cast_eval), + (CAST_DOUBLE, CAST_UTINYINT) => run!(float64_to_utinyint_cast_eval), + (CAST_DOUBLE, CAST_SMALLINT) => run!(float64_to_smallint_cast_eval), + (CAST_DOUBLE, CAST_USMALLINT) => run!(float64_to_usmallint_cast_eval), + (CAST_DOUBLE, CAST_INTEGER) => run!(float64_to_integer_cast_eval), + (CAST_DOUBLE, CAST_UINTEGER) => run!(float64_to_uinteger_cast_eval), + (CAST_DOUBLE, CAST_BIGINT) => run!(float64_to_bigint_cast_eval), + (CAST_DOUBLE, CAST_UBIGINT) => run!(float64_to_ubigint_cast_eval), + (CAST_DOUBLE, CAST_CHAR) => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(float64_to_char_cast_eval { len, unit }) + } + (CAST_DOUBLE, CAST_VARCHAR) => { + let (len, unit) = string_param(params); + run!(float64_to_varchar_cast_eval { len, unit }) + } + #[cfg(feature = "decimal")] + (CAST_DOUBLE, CAST_DECIMAL) => { + let (precision, scale) = decimal_param(params); + run!(float64_to_decimal_cast_eval { precision, scale }) + } + (CAST_CHAR | CAST_VARCHAR, CAST_BOOLEAN) => { + run!(utf8_to_boolean_cast_eval) + } + (CAST_CHAR | CAST_VARCHAR, CAST_TINYINT) => run!(utf8_to_tinyint_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_UTINYINT) => run!(utf8_to_utinyint_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_SMALLINT) => run!(utf8_to_smallint_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_USMALLINT) => run!(utf8_to_usmallint_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_INTEGER) => run!(utf8_to_integer_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_UINTEGER) => run!(utf8_to_uinteger_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_BIGINT) => run!(utf8_to_bigint_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_UBIGINT) => run!(utf8_to_ubigint_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_FLOAT) => run!(utf8_to_float_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_DOUBLE) => run!(utf8_to_double_cast_eval), + (CAST_CHAR | CAST_VARCHAR, CAST_CHAR) => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(utf8_to_char_cast_eval { len, unit }) + } + (CAST_CHAR | CAST_VARCHAR, CAST_VARCHAR) => { + let (len, unit) = string_param(params); + run!(utf8_to_varchar_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_CHAR | CAST_VARCHAR, CAST_DATE) => run!(utf8_to_date_cast_eval), + #[cfg(feature = "time")] + (CAST_CHAR | CAST_VARCHAR, CAST_DATETIME) => run!(utf8_to_datetime_cast_eval), + #[cfg(feature = "time")] + (CAST_CHAR | CAST_VARCHAR, CAST_TIME) => run!(utf8_to_time_cast_eval { + precision: precision_param(params) + }), + #[cfg(feature = "time")] + (CAST_CHAR | CAST_VARCHAR, CAST_TIMESTAMP) => { + let (precision, zone) = timestamp_param(params); + run!(utf8_to_timestamp_cast_eval { precision, zone }) + } + #[cfg(feature = "decimal")] + (CAST_CHAR | CAST_VARCHAR, CAST_DECIMAL) => run!(utf8_to_decimal_cast_eval), + #[cfg(feature = "time")] + (CAST_DATE, CAST_CHAR) => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(date32_to_char_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_DATE, CAST_VARCHAR) => { + let (len, unit) = string_param(params); + run!(date32_to_varchar_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_DATE, CAST_DATETIME) => run!(date32_to_datetime_cast_eval), + #[cfg(feature = "time")] + (CAST_DATETIME, CAST_CHAR) => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(date64_to_char_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_DATETIME, CAST_VARCHAR) => { + let (len, unit) = string_param(params); + run!(date64_to_varchar_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_DATETIME, CAST_DATE) => run!(date64_to_date_cast_eval), + #[cfg(feature = "time")] + (CAST_DATETIME, CAST_TIME) => run!(date64_to_time_cast_eval { + precision: precision_param(params) + }), + #[cfg(feature = "time")] + (CAST_DATETIME, CAST_TIMESTAMP) => { + let (precision, zone) = timestamp_param(params); + run!(date64_to_timestamp_cast_eval { precision, zone }) + } + #[cfg(feature = "time")] + (CAST_TIME, CAST_CHAR) => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(time32_to_char_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_TIME, CAST_VARCHAR) => { + let (len, unit) = string_param(params); + run!(time32_to_varchar_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_TIME, CAST_TIME) => run!(time32_to_time_cast_eval { + precision: precision_param(params) + }), + #[cfg(feature = "time")] + (CAST_TIMESTAMP, CAST_CHAR) => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(time64_to_char_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_TIMESTAMP, CAST_VARCHAR) => { + let (len, unit) = string_param(params); + run!(time64_to_varchar_cast_eval { len, unit }) + } + #[cfg(feature = "time")] + (CAST_TIMESTAMP, CAST_DATE) => run!(time64_to_date_cast_eval), + #[cfg(feature = "time")] + (CAST_TIMESTAMP, CAST_DATETIME) => run!(time64_to_datetime_cast_eval), + #[cfg(feature = "time")] + (CAST_TIMESTAMP, CAST_TIME) => run!(time64_to_time_cast_eval { + precision: precision_param(params) + }), + #[cfg(feature = "time")] + (CAST_TIMESTAMP, CAST_TIMESTAMP) => { + let (precision, zone) = timestamp_param(params); + run!(time64_to_timestamp_cast_eval { precision, zone }) + } + #[cfg(feature = "decimal")] + (CAST_DECIMAL, to_code) => match to_code { + CAST_FLOAT => run!(decimal_to_float_cast_eval), + CAST_DOUBLE => run!(decimal_to_double_cast_eval), + CAST_DECIMAL => run!(decimal_to_decimal_cast_eval), + CAST_CHAR => { + let (len, unit) = string_param(params); + let len = len.expect("char cast must have fixed length"); + run!(decimal_to_char_cast_eval { len, unit }) } - ) + CAST_VARCHAR => { + let (len, unit) = string_param(params); + run!(decimal_to_varchar_cast_eval { len, unit }) + } + CAST_TINYINT => run!(decimal_to_tinyint_cast_eval), + CAST_SMALLINT => run!(decimal_to_smallint_cast_eval), + CAST_INTEGER => run!(decimal_to_integer_cast_eval), + CAST_BIGINT => run!(decimal_to_bigint_cast_eval), + CAST_UTINYINT => run!(decimal_to_utinyint_cast_eval), + CAST_USMALLINT => run!(decimal_to_usmallint_cast_eval), + CAST_UINTEGER => run!(decimal_to_uinteger_cast_eval), + CAST_UBIGINT => run!(decimal_to_ubigint_cast_eval), + _ => unreachable!("invalid decimal cast evaluator position"), + }, + (CAST_TUPLE, CAST_TUPLE) => { + let CastEvaluatorParams::Tuple { evaluators } = params else { + unreachable!("tuple cast must have tuple parameters") + }; + eval_tuple_cast(evaluators, value) + } + _ => unreachable!("invalid cast evaluator position"), } - _ => Err(cast_fail(from.clone(), to.clone())), } } #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use super::cast_create; + use super::*; use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::rocksdb::RocksTransaction; - use crate::types::evaluator::CastEvaluatorBox; + use crate::types::evaluator::CastEvaluatorRef; use crate::types::LogicalType; use std::borrow::Cow; use std::io::{Cursor, Seek, SeekFrom}; - fn create(from: LogicalType, to: LogicalType) -> Result { + fn create(from: LogicalType, to: LogicalType) -> Result { cast_create(Cow::Owned(from), Cow::Owned(to)) } + #[test] + fn test_cast_evaluator_positions_are_stable() -> Result<(), DatabaseError> { + assert_eq!( + create(LogicalType::Integer, LogicalType::Bigint)?.pos(), + CAST_INTEGER * CAST_TYPE_STRIDE + CAST_BIGINT + ); + assert_eq!( + create( + LogicalType::Varchar(None, crate::types::CharLengthUnits::Characters), + LogicalType::Date + )? + .pos(), + CAST_VARCHAR * CAST_TYPE_STRIDE + CAST_DATE + ); + assert_eq!( + create( + LogicalType::TimeStamp(Some(3), true), + LogicalType::Time(Some(0)) + )? + .pos(), + CAST_TIMESTAMP * CAST_TYPE_STRIDE + CAST_TIME + ); + + Ok(()) + } + #[test] fn test_cast_evaluator_serialization() -> Result<(), DatabaseError> { let evaluator = create(LogicalType::Integer, LogicalType::Bigint)?; let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); - evaluator.encode(&mut cursor, false, &mut reference_tables)?; + evaluator.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( - CastEvaluatorBox::decode::(&mut cursor, None, &reference_tables)?, + CastEvaluatorRef::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena + )?, evaluator ); diff --git a/src/types/evaluator/date.rs b/src/types/evaluator/date.rs index 16457edb..d88b3140 100644 --- a/src/types/evaluator/date.rs +++ b/src/types/evaluator/date.rs @@ -21,42 +21,40 @@ use chrono::NaiveDate; numeric_binary_evaluator_definition!(Date, DataValue::Date32); crate::define_cast_evaluator!( - Date32ToCharCastEvaluator { + date32_to_char_cast_eval { len: u32, - unit: CharLengthUnits, - to: LogicalType + unit: CharLengthUnits }, DataValue::Date32(value) => |this| { to_char( - DataValue::format_date(*value).ok_or_else(|| cast_fail(LogicalType::Date, this.to.clone()))?, + DataValue::format_date(*value).ok_or_else(|| { + cast_fail(LogicalType::Date, LogicalType::Char(this.len, this.unit)) + })?, this.len, this.unit, ) } ); crate::define_cast_evaluator!( - Date32ToVarcharCastEvaluator { + date32_to_varchar_cast_eval { len: Option, - unit: CharLengthUnits, - to: LogicalType + unit: CharLengthUnits }, DataValue::Date32(value) => |this| { to_varchar( - DataValue::format_date(*value).ok_or_else(|| cast_fail(LogicalType::Date, this.to.clone()))?, + DataValue::format_date(*value).ok_or_else(|| { + cast_fail(LogicalType::Date, LogicalType::Varchar(this.len, this.unit)) + })?, this.len, this.unit, ) } ); -crate::define_cast_evaluator!( - Date32ToDatetimeCastEvaluator { - to: LogicalType - }, - DataValue::Date32(value) => |this| { +crate::define_cast_evaluator!(date32_to_datetime_cast_eval, DataValue::Date32(value) => { let value = NaiveDate::from_num_days_from_ce_opt(*value) - .ok_or_else(|| cast_fail(LogicalType::Date, this.to.clone()))? + .ok_or_else(|| cast_fail(LogicalType::Date, LogicalType::DateTime))? .and_hms_opt(0, 0, 0) - .ok_or_else(|| cast_fail(LogicalType::Date, this.to.clone()))? + .ok_or_else(|| cast_fail(LogicalType::Date, LogicalType::DateTime))? .and_utc() .timestamp(); @@ -67,7 +65,6 @@ crate::define_cast_evaluator!( #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use chrono::Datelike; @@ -79,13 +76,7 @@ mod test { .num_days_from_ce(), ); assert_eq!( - Date32ToCharCastEvaluator { - len: 10, - unit: CharLengthUnits::Characters, - to: LogicalType::Char(10, CharLengthUnits::Characters), - } - .eval_cast(&value) - .unwrap(), + date32_to_char_cast_eval(10, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "2024-01-02".to_string(), ty: Utf8Type::Fixed(10), @@ -93,13 +84,7 @@ mod test { } ); assert_eq!( - Date32ToVarcharCastEvaluator { - len: Some(10), - unit: CharLengthUnits::Characters, - to: LogicalType::Varchar(Some(10), CharLengthUnits::Characters), - } - .eval_cast(&value) - .unwrap(), + date32_to_varchar_cast_eval(Some(10), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "2024-01-02".to_string(), ty: Utf8Type::Variable(Some(10)), @@ -107,11 +92,7 @@ mod test { } ); assert_eq!( - Date32ToDatetimeCastEvaluator { - to: LogicalType::DateTime, - } - .eval_cast(&value) - .unwrap(), + date32_to_datetime_cast_eval(&value).unwrap(), DataValue::Date64( NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() diff --git a/src/types/evaluator/datetime.rs b/src/types/evaluator/datetime.rs index 53094446..f4fe9819 100644 --- a/src/types/evaluator/datetime.rs +++ b/src/types/evaluator/datetime.rs @@ -21,40 +21,38 @@ use chrono::{DateTime, Datelike, Timelike}; numeric_binary_evaluator_definition!(DateTime, DataValue::Date64); crate::define_cast_evaluator!( - Date64ToCharCastEvaluator { + date64_to_char_cast_eval { len: u32, - unit: CharLengthUnits, - to: LogicalType + unit: CharLengthUnits }, DataValue::Date64(value) => |this| { to_char( - DataValue::format_datetime(*value).ok_or_else(|| cast_fail(LogicalType::DateTime, this.to.clone()))?, + DataValue::format_datetime(*value).ok_or_else(|| { + cast_fail(LogicalType::DateTime, LogicalType::Char(this.len, this.unit)) + })?, this.len, this.unit, ) } ); crate::define_cast_evaluator!( - Date64ToVarcharCastEvaluator { + date64_to_varchar_cast_eval { len: Option, - unit: CharLengthUnits, - to: LogicalType + unit: CharLengthUnits }, DataValue::Date64(value) => |this| { to_varchar( - DataValue::format_datetime(*value).ok_or_else(|| cast_fail(LogicalType::DateTime, this.to.clone()))?, + DataValue::format_datetime(*value).ok_or_else(|| { + cast_fail(LogicalType::DateTime, LogicalType::Varchar(this.len, this.unit)) + })?, this.len, this.unit, ) } ); -crate::define_cast_evaluator!( - Date64ToDateCastEvaluator { - to: LogicalType - }, - DataValue::Date64(value) => |this| { +crate::define_cast_evaluator!(date64_to_date_cast_eval, DataValue::Date64(value) => { let value = DateTime::from_timestamp(*value, 0) - .ok_or_else(|| cast_fail(LogicalType::DateTime, this.to.clone()))? + .ok_or_else(|| cast_fail(LogicalType::DateTime, LogicalType::Date))? .naive_utc() .date() .num_days_from_ce(); @@ -63,21 +61,22 @@ crate::define_cast_evaluator!( } ); crate::define_cast_evaluator!( - Date64ToTimeCastEvaluator { - precision: Option, - to: LogicalType + date64_to_time_cast_eval { + precision: Option }, DataValue::Date64(value) => |this| { let precision = this.precision.unwrap_or(0); let value = DateTime::from_timestamp(*value, 0) .map(|date_time| date_time.time().num_seconds_from_midnight()) - .ok_or_else(|| cast_fail(LogicalType::DateTime, this.to.clone()))?; + .ok_or_else(|| { + cast_fail(LogicalType::DateTime, LogicalType::Time(this.precision)) + })?; Ok(DataValue::Time32(DataValue::pack(value, 0, 0), precision)) } ); crate::define_cast_evaluator!( - Date64ToTimestampCastEvaluator { + date64_to_timestamp_cast_eval { precision: Option, zone: bool }, @@ -89,7 +88,6 @@ crate::define_cast_evaluator!( #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; @@ -104,13 +102,7 @@ mod test { .timestamp(), ); assert_eq!( - Date64ToCharCastEvaluator { - len: 19, - unit: CharLengthUnits::Characters, - to: LogicalType::Char(19, CharLengthUnits::Characters), - } - .eval_cast(&value) - .unwrap(), + date64_to_char_cast_eval(19, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "2024-01-02 03:04:05".to_string(), ty: Utf8Type::Fixed(19), @@ -118,13 +110,7 @@ mod test { } ); assert_eq!( - Date64ToVarcharCastEvaluator { - len: Some(19), - unit: CharLengthUnits::Characters, - to: LogicalType::Varchar(Some(19), CharLengthUnits::Characters), - } - .eval_cast(&value) - .unwrap(), + date64_to_varchar_cast_eval(Some(19), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "2024-01-02 03:04:05".to_string(), ty: Utf8Type::Variable(Some(19)), @@ -132,11 +118,7 @@ mod test { } ); assert_eq!( - Date64ToDateCastEvaluator { - to: LogicalType::Date - } - .eval_cast(&value) - .unwrap(), + date64_to_date_cast_eval(&value).unwrap(), DataValue::Date32( chrono::NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() @@ -144,21 +126,11 @@ mod test { ) ); assert_eq!( - Date64ToTimeCastEvaluator { - precision: Some(0), - to: LogicalType::Time(Some(0)), - } - .eval_cast(&value) - .unwrap(), + date64_to_time_cast_eval(Some(0), &value).unwrap(), DataValue::Time32(DataValue::pack(3 * 3600 + 4 * 60 + 5, 0, 0), 0) ); assert_eq!( - Date64ToTimestampCastEvaluator { - precision: Some(0), - zone: true, - } - .eval_cast(&value) - .unwrap(), + date64_to_timestamp_cast_eval(Some(0), true, &value).unwrap(), DataValue::Time64( chrono::NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() diff --git a/src/types/evaluator/decimal.rs b/src/types/evaluator/decimal.rs index ab2d5ec9..0b015696 100644 --- a/src/types/evaluator/decimal.rs +++ b/src/types/evaluator/decimal.rs @@ -14,59 +14,36 @@ use crate::errors::DatabaseError; use crate::types::evaluator::cast::{to_char, to_varchar}; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; use rust_decimal::prelude::ToPrimitive; use std::hint; -#[derive(Debug)] -pub struct DecimalPlusBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalMinusBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalMultiplyBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalDivideBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalGtBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalGtEqBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalLtBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalLtEqBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalEqBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalNotEqBinaryEvaluator; -#[derive(Debug)] -pub struct DecimalModBinaryEvaluator; -impl BinaryEvaluator for DecimalPlusBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Decimal(v1 + v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalMinusBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Decimal(v1 - v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +macro_rules! decimal_binary { + ($name:ident, $body:expr) => { + pub fn $name(left: &DataValue, right: &DataValue) -> Result { + Ok(match (left, right) { + (DataValue::Decimal(v1), DataValue::Decimal(v2)) => $body(v1, v2), + (DataValue::Decimal(_), DataValue::Null) + | (DataValue::Null, DataValue::Decimal(_)) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) + } + }; } -crate::define_cast_evaluator!(DecimalToFloatCastEvaluator, DataValue::Decimal(value) => { +decimal_binary!( + decimal_plus_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Decimal(v1 + v2) +); +decimal_binary!( + decimal_minus_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Decimal(v1 - v2) +); + +crate::define_cast_evaluator!(decimal_to_float_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::Float32(OrderedFloat(value.to_f32().ok_or_else(|| { crate::types::evaluator::cast::cast_fail( crate::types::LogicalType::Decimal(None, None), @@ -74,7 +51,7 @@ crate::define_cast_evaluator!(DecimalToFloatCastEvaluator, DataValue::Decimal(va ) })?))) }); -crate::define_cast_evaluator!(DecimalToDoubleCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_double_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::Float64(OrderedFloat(value.to_f64().ok_or_else(|| { crate::types::evaluator::cast::cast_fail( crate::types::LogicalType::Decimal(None, None), @@ -82,151 +59,87 @@ crate::define_cast_evaluator!(DecimalToDoubleCastEvaluator, DataValue::Decimal(v ) })?))) }); -crate::define_cast_evaluator!(DecimalToDecimalCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_decimal_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::Decimal(*value)) }); crate::define_cast_evaluator!( - DecimalToCharCastEvaluator { + decimal_to_char_cast_eval { len: u32, unit: CharLengthUnits }, DataValue::Decimal(value) => |this| to_char(value.to_string(), this.len, this.unit) ); crate::define_cast_evaluator!( - DecimalToVarcharCastEvaluator { + decimal_to_varchar_cast_eval { len: Option, unit: CharLengthUnits }, DataValue::Decimal(value) => |this| to_varchar(value.to_string(), this.len, this.unit) ); -crate::define_cast_evaluator!(DecimalToTinyintCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_tinyint_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::Int8(crate::decimal_to_int_cast!(*value, i8))) }); -crate::define_cast_evaluator!(DecimalToSmallintCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_smallint_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::Int16(crate::decimal_to_int_cast!(*value, i16))) }); -crate::define_cast_evaluator!(DecimalToIntegerCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_integer_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::Int32(crate::decimal_to_int_cast!(*value, i32))) }); -crate::define_cast_evaluator!(DecimalToBigintCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_bigint_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::Int64(crate::decimal_to_int_cast!(*value, i64))) }); -crate::define_cast_evaluator!(DecimalToUTinyintCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_utinyint_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::UInt8(crate::decimal_to_int_cast!(*value, u8))) }); -crate::define_cast_evaluator!(DecimalToUSmallintCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_usmallint_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::UInt16(crate::decimal_to_int_cast!(*value, u16))) }); -crate::define_cast_evaluator!(DecimalToUIntegerCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_uinteger_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::UInt32(crate::decimal_to_int_cast!(*value, u32))) }); -crate::define_cast_evaluator!(DecimalToUBigintCastEvaluator, DataValue::Decimal(value) => { +crate::define_cast_evaluator!(decimal_to_ubigint_cast_eval, DataValue::Decimal(value) => { Ok(DataValue::UInt64(crate::decimal_to_int_cast!(*value, u64))) }); -impl BinaryEvaluator for DecimalMultiplyBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Decimal(v1 * v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalDivideBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Decimal(v1 / v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalGtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Boolean(v1 > v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalGtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Boolean(v1 >= v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalLtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Boolean(v1 < v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalLtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Boolean(v1 <= v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Boolean(v1 == v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalNotEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Boolean(v1 != v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for DecimalModBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Decimal(v1), DataValue::Decimal(v2)) => DataValue::Decimal(v1 % v2), - (DataValue::Decimal(_), DataValue::Null) - | (DataValue::Null, DataValue::Decimal(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} +decimal_binary!( + decimal_multiply_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Decimal(v1 * v2) +); +decimal_binary!( + decimal_divide_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Decimal(v1 / v2) +); +decimal_binary!( + decimal_gt_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Boolean(v1 > v2) +); +decimal_binary!( + decimal_gt_eq_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Boolean(v1 >= v2) +); +decimal_binary!( + decimal_lt_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Boolean(v1 < v2) +); +decimal_binary!( + decimal_lt_eq_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Boolean(v1 <= v2) +); +decimal_binary!( + decimal_eq_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Boolean(v1 == v2) +); +decimal_binary!( + decimal_not_eq_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Boolean(v1 != v2) +); +decimal_binary!( + decimal_mod_binary_eval, + |v1: &rust_decimal::Decimal, v2: &rust_decimal::Decimal| DataValue::Decimal(v1 % v2) +); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use rust_decimal::Decimal; @@ -236,24 +149,19 @@ mod test { let value = DataValue::Decimal(Decimal::new(125, 1)); assert_eq!( - DecimalToFloatCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(12.5)) ); assert_eq!( - DecimalToDoubleCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(12.5)) ); assert_eq!( - DecimalToDecimalCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_decimal_cast_eval(&value).unwrap(), DataValue::Decimal(Decimal::new(125, 1)) ); assert_eq!( - DecimalToCharCastEvaluator { - len: 4, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + decimal_to_char_cast_eval(4, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "12.5".to_string(), ty: Utf8Type::Fixed(4), @@ -261,12 +169,7 @@ mod test { } ); assert_eq!( - DecimalToVarcharCastEvaluator { - len: Some(4), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + decimal_to_varchar_cast_eval(Some(4), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "12.5".to_string(), ty: Utf8Type::Variable(Some(4)), @@ -274,35 +177,35 @@ mod test { } ); assert_eq!( - DecimalToTinyintCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(12) ); assert_eq!( - DecimalToSmallintCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(12) ); assert_eq!( - DecimalToIntegerCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(12) ); assert_eq!( - DecimalToBigintCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(12) ); assert_eq!( - DecimalToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(12) ); assert_eq!( - DecimalToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(12) ); assert_eq!( - DecimalToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(12) ); assert_eq!( - DecimalToUBigintCastEvaluator.eval_cast(&value).unwrap(), + decimal_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(12) ); } diff --git a/src/types/evaluator/float32.rs b/src/types/evaluator/float32.rs index 3133e724..02e1fda0 100644 --- a/src/types/evaluator/float32.rs +++ b/src/types/evaluator/float32.rs @@ -14,185 +14,112 @@ use crate::errors::DatabaseError; use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; +#[cfg(feature = "decimal")] use crate::types::LogicalType; +#[cfg(feature = "decimal")] use rust_decimal::prelude::FromPrimitive; use std::hint; - -#[derive(Debug)] -pub struct Float32PlusUnaryEvaluator; -#[derive(Debug)] -pub struct Float32MinusUnaryEvaluator; -impl UnaryEvaluator for Float32PlusUnaryEvaluator { - fn unary_eval(&self, value: &DataValue) -> DataValue { - value.clone() +pub fn float32_plus_unary_eval(value: &DataValue) -> DataValue { + value.clone() +} +pub fn float32_minus_unary_eval(value: &DataValue) -> DataValue { + match value { + DataValue::Float32(value) => DataValue::Float32(-value), + DataValue::Null => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, } } -impl UnaryEvaluator for Float32MinusUnaryEvaluator { - fn unary_eval(&self, value: &DataValue) -> DataValue { - match value { - DataValue::Float32(value) => DataValue::Float32(-value), - DataValue::Null => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, + +macro_rules! float32_binary { + ($name:ident, $body:expr) => { + pub fn $name(left: &DataValue, right: &DataValue) -> Result { + Ok(match (left, right) { + (DataValue::Float32(v1), DataValue::Float32(v2)) => $body(v1, v2), + (DataValue::Float32(_), DataValue::Null) + | (DataValue::Null, DataValue::Float32(_)) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } - } + }; } -#[derive(Debug)] -pub struct Float32PlusBinaryEvaluator; -#[derive(Debug)] -pub struct Float32MinusBinaryEvaluator; -#[derive(Debug)] -pub struct Float32MultiplyBinaryEvaluator; -#[derive(Debug)] -pub struct Float32DivideBinaryEvaluator; -#[derive(Debug)] -pub struct Float32GtBinaryEvaluator; -#[derive(Debug)] -pub struct Float32GtEqBinaryEvaluator; -#[derive(Debug)] -pub struct Float32LtBinaryEvaluator; -#[derive(Debug)] -pub struct Float32LtEqBinaryEvaluator; -#[derive(Debug)] -pub struct Float32EqBinaryEvaluator; -#[derive(Debug)] -pub struct Float32NotEqBinaryEvaluator; -#[derive(Debug)] -pub struct Float32ModBinaryEvaluator; -impl BinaryEvaluator for Float32PlusBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Float32(*v1 + *v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +float32_binary!( + float32_plus_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float32(*v1 + *v2) } -} -impl BinaryEvaluator for Float32MinusBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Float32(*v1 - *v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_minus_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float32(*v1 - *v2) } -} - -crate::define_float_cast_evaluators!(Float32, Float32, f32, LogicalType::Float, from_f32); -impl BinaryEvaluator for Float32MultiplyBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Float32(*v1 * *v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_multiply_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float32(*v1 * *v2) } -} -impl BinaryEvaluator for Float32DivideBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => { - DataValue::Float64(ordered_float::OrderedFloat(**v1 as f64 / **v2 as f64)) - } - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_divide_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float64(ordered_float::OrderedFloat(v1.0 as f64 / v2.0 as f64)) } -} -impl BinaryEvaluator for Float32GtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Boolean(v1 > v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_gt_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 > v2) } -} -impl BinaryEvaluator for Float32GtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Boolean(v1 >= v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_gt_eq_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 >= v2) } -} -impl BinaryEvaluator for Float32LtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Boolean(v1 < v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_lt_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 < v2) } -} -impl BinaryEvaluator for Float32LtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Boolean(v1 <= v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_lt_eq_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 <= v2) } -} -impl BinaryEvaluator for Float32EqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Boolean(v1 == v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_eq_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 == v2) } -} -impl BinaryEvaluator for Float32NotEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Boolean(v1 != v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_not_eq_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 != v2) } -} -impl BinaryEvaluator for Float32ModBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float32(v1), DataValue::Float32(v2)) => DataValue::Float32(*v1 % *v2), - (DataValue::Float32(_), DataValue::Null) - | (DataValue::Null, DataValue::Float32(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float32_binary!( + float32_mod_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float32(*v1 % *v2) } -} +); + +crate::define_float_cast_evaluators!(Float32, Float32, f32, LogicalType::Float, from_f32); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; + #[cfg(feature = "decimal")] use rust_decimal::Decimal; #[test] @@ -200,52 +127,47 @@ mod test { let value = DataValue::Float32(OrderedFloat(1.5)); assert_eq!( - Float32ToFloatCastEvaluator.eval_cast(&value).unwrap(), + float32_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.5)) ); assert_eq!( - Float32ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + float32_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.5)) ); assert_eq!( - Float32ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + float32_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - Float32ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + float32_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - Float32ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + float32_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - Float32ToBigintCastEvaluator.eval_cast(&value).unwrap(), + float32_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - Float32ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + float32_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - Float32ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + float32_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - Float32ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + float32_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - Float32ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + float32_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - Float32ToCharCastEvaluator { - len: 3, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + float32_to_char_cast_eval(3, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1.5".to_string(), ty: Utf8Type::Fixed(3), @@ -253,25 +175,16 @@ mod test { } ); assert_eq!( - Float32ToVarcharCastEvaluator { - len: Some(3), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + float32_to_varchar_cast_eval(Some(3), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1.5".to_string(), ty: Utf8Type::Variable(Some(3)), unit: CharLengthUnits::Characters, } ); + #[cfg(feature = "decimal")] assert_eq!( - Float32ToDecimalCastEvaluator { - scale: Some(1), - to: LogicalType::Decimal(None, Some(1)), - } - .eval_cast(&value) - .unwrap(), + float32_to_decimal_cast_eval(None, Some(1), &value).unwrap(), DataValue::Decimal(Decimal::new(15, 1)) ); } diff --git a/src/types/evaluator/float64.rs b/src/types/evaluator/float64.rs index c8358c74..3fcb7411 100644 --- a/src/types/evaluator/float64.rs +++ b/src/types/evaluator/float64.rs @@ -13,245 +13,180 @@ // limitations under the License. use crate::errors::DatabaseError; +#[cfg(feature = "decimal")] use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; +#[cfg(not(feature = "decimal"))] +use crate::types::evaluator::cast::{to_char, to_varchar}; use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, UnaryEvaluator}; use crate::types::CharLengthUnits; +#[cfg(feature = "decimal")] use crate::types::LogicalType; +#[cfg(feature = "decimal")] use rust_decimal::prelude::FromPrimitive; +#[cfg(feature = "decimal")] use rust_decimal::Decimal; use std::hint; - -#[derive(Debug)] -pub struct Float64PlusUnaryEvaluator; -#[derive(Debug)] -pub struct Float64MinusUnaryEvaluator; -impl UnaryEvaluator for Float64PlusUnaryEvaluator { - fn unary_eval(&self, value: &DataValue) -> DataValue { - value.clone() +pub fn float64_plus_unary_eval(value: &DataValue) -> DataValue { + value.clone() +} +pub fn float64_minus_unary_eval(value: &DataValue) -> DataValue { + match value { + DataValue::Float64(value) => DataValue::Float64(-value), + DataValue::Null => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, } } -impl UnaryEvaluator for Float64MinusUnaryEvaluator { - fn unary_eval(&self, value: &DataValue) -> DataValue { - match value { - DataValue::Float64(value) => DataValue::Float64(-value), - DataValue::Null => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, + +macro_rules! float64_binary { + ($name:ident, $body:expr) => { + pub fn $name(left: &DataValue, right: &DataValue) -> Result { + Ok(match (left, right) { + (DataValue::Float64(v1), DataValue::Float64(v2)) => $body(v1, v2), + (DataValue::Float64(_), DataValue::Null) + | (DataValue::Null, DataValue::Float64(_)) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } - } + }; } -#[derive(Debug)] -pub struct Float64PlusBinaryEvaluator; -#[derive(Debug)] -pub struct Float64MinusBinaryEvaluator; -#[derive(Debug)] -pub struct Float64MultiplyBinaryEvaluator; -#[derive(Debug)] -pub struct Float64DivideBinaryEvaluator; -#[derive(Debug)] -pub struct Float64GtBinaryEvaluator; -#[derive(Debug)] -pub struct Float64GtEqBinaryEvaluator; -#[derive(Debug)] -pub struct Float64LtBinaryEvaluator; -#[derive(Debug)] -pub struct Float64LtEqBinaryEvaluator; -#[derive(Debug)] -pub struct Float64EqBinaryEvaluator; -#[derive(Debug)] -pub struct Float64NotEqBinaryEvaluator; -#[derive(Debug)] -pub struct Float64ModBinaryEvaluator; -impl BinaryEvaluator for Float64PlusBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Float64(*v1 + *v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +float64_binary!( + float64_plus_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float64(*v1 + *v2) } -} -impl BinaryEvaluator for Float64MinusBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Float64(*v1 - *v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_minus_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float64(*v1 - *v2) } -} +); -crate::define_cast_evaluator!(Float64ToFloatCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_float_cast_eval, DataValue::Float64(value) => { Ok(DataValue::Float32(ordered_float::OrderedFloat(value.0 as f32))) }); -crate::define_cast_evaluator!(Float64ToDoubleCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_double_cast_eval, DataValue::Float64(value) => { Ok(DataValue::Float64(*value)) }); -crate::define_cast_evaluator!(Float64ToTinyintCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_tinyint_cast_eval, DataValue::Float64(value) => { Ok(DataValue::Int8(crate::float_to_int_cast!(value.into_inner(), i8, f64)?)) }); -crate::define_cast_evaluator!(Float64ToSmallintCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_smallint_cast_eval, DataValue::Float64(value) => { Ok(DataValue::Int16(crate::float_to_int_cast!(value.into_inner(), i16, f64)?)) }); -crate::define_cast_evaluator!(Float64ToIntegerCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_integer_cast_eval, DataValue::Float64(value) => { Ok(DataValue::Int32(crate::float_to_int_cast!(value.into_inner(), i32, f64)?)) }); -crate::define_cast_evaluator!(Float64ToBigintCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_bigint_cast_eval, DataValue::Float64(value) => { Ok(DataValue::Int64(crate::float_to_int_cast!(value.into_inner(), i64, f64)?)) }); -crate::define_cast_evaluator!(Float64ToUTinyintCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_utinyint_cast_eval, DataValue::Float64(value) => { Ok(DataValue::UInt8(crate::float_to_int_cast!(value.into_inner(), u8, f64)?)) }); -crate::define_cast_evaluator!(Float64ToUSmallintCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_usmallint_cast_eval, DataValue::Float64(value) => { Ok(DataValue::UInt16(crate::float_to_int_cast!(value.into_inner(), u16, f64)?)) }); -crate::define_cast_evaluator!(Float64ToUIntegerCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_uinteger_cast_eval, DataValue::Float64(value) => { Ok(DataValue::UInt32(crate::float_to_int_cast!(value.into_inner(), u32, f64)?)) }); -crate::define_cast_evaluator!(Float64ToUBigintCastEvaluator, DataValue::Float64(value) => { +crate::define_cast_evaluator!(float64_to_ubigint_cast_eval, DataValue::Float64(value) => { Ok(DataValue::UInt64(crate::float_to_int_cast!(value.into_inner(), u64, f64)?)) }); crate::define_cast_evaluator!( - Float64ToCharCastEvaluator { + float64_to_char_cast_eval { len: u32, unit: CharLengthUnits }, DataValue::Float64(value) => |this| to_char(value.to_string(), this.len, this.unit) ); crate::define_cast_evaluator!( - Float64ToVarcharCastEvaluator { + float64_to_varchar_cast_eval { len: Option, unit: CharLengthUnits }, DataValue::Float64(value) => |this| to_varchar(value.to_string(), this.len, this.unit) ); +#[cfg(feature = "decimal")] crate::define_cast_evaluator!( - Float64ToDecimalCastEvaluator { + float64_to_decimal_cast_eval { + precision: Option, scale: Option, - to: LogicalType }, DataValue::Float64(value) => |this| { let mut decimal = Decimal::from_f64(value.0).ok_or_else(|| { - cast_fail(LogicalType::Double, this.to.clone()) + cast_fail( + LogicalType::Double, + LogicalType::Decimal(this.precision, this.scale), + ) })?; DataValue::decimal_round_f(&this.scale, &mut decimal); Ok(DataValue::Decimal(decimal)) } ); -impl BinaryEvaluator for Float64MultiplyBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Float64(*v1 * *v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +float64_binary!( + float64_multiply_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float64(*v1 * *v2) } -} -impl BinaryEvaluator for Float64DivideBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => { - DataValue::Float64(ordered_float::OrderedFloat(**v1 / **v2)) - } - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_divide_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float64(ordered_float::OrderedFloat(v1.0 / v2.0)) } -} -impl BinaryEvaluator for Float64GtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Boolean(v1 > v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_gt_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 > v2) } -} -impl BinaryEvaluator for Float64GtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Boolean(v1 >= v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_gt_eq_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 >= v2) } -} -impl BinaryEvaluator for Float64LtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Boolean(v1 < v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_lt_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 < v2) } -} -impl BinaryEvaluator for Float64LtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Boolean(v1 <= v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_lt_eq_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 <= v2) } -} -impl BinaryEvaluator for Float64EqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Boolean(v1 == v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_eq_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 == v2) } -} -impl BinaryEvaluator for Float64NotEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Boolean(v1 != v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_not_eq_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Boolean(v1 != v2) } -} -impl BinaryEvaluator for Float64ModBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Float64(v1), DataValue::Float64(v2)) => DataValue::Float64(*v1 % *v2), - (DataValue::Float64(_), DataValue::Null) - | (DataValue::Null, DataValue::Float64(_)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +); +float64_binary!( + float64_mod_binary_eval, + |v1: &ordered_float::OrderedFloat, v2: &ordered_float::OrderedFloat| { + DataValue::Float64(*v1 % *v2) } -} +); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; + #[cfg(feature = "decimal")] use rust_decimal::Decimal; #[test] @@ -259,61 +194,55 @@ mod test { let value = DataValue::Float64(ordered_float::OrderedFloat(1.5)); assert_eq!( - Float64MultiplyBinaryEvaluator - .binary_eval( - &DataValue::Float64(ordered_float::OrderedFloat(1.5)), - &DataValue::Float64(ordered_float::OrderedFloat(2.0)), - ) - .unwrap(), + float64_multiply_binary_eval( + &DataValue::Float64(ordered_float::OrderedFloat(1.5)), + &DataValue::Float64(ordered_float::OrderedFloat(2.0)), + ) + .unwrap(), DataValue::Float64(ordered_float::OrderedFloat(3.0)) ); assert_eq!( - Float64ToFloatCastEvaluator.eval_cast(&value).unwrap(), + float64_to_float_cast_eval(&value).unwrap(), DataValue::Float32(ordered_float::OrderedFloat(1.5)) ); assert_eq!( - Float64ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + float64_to_double_cast_eval(&value).unwrap(), DataValue::Float64(ordered_float::OrderedFloat(1.5)) ); assert_eq!( - Float64ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + float64_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - Float64ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + float64_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - Float64ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + float64_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - Float64ToBigintCastEvaluator.eval_cast(&value).unwrap(), + float64_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - Float64ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + float64_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - Float64ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + float64_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - Float64ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + float64_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - Float64ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + float64_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - Float64ToCharCastEvaluator { - len: 3, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + float64_to_char_cast_eval(3, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1.5".to_string(), ty: Utf8Type::Fixed(3), @@ -321,25 +250,16 @@ mod test { } ); assert_eq!( - Float64ToVarcharCastEvaluator { - len: Some(3), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + float64_to_varchar_cast_eval(Some(3), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1.5".to_string(), ty: Utf8Type::Variable(Some(3)), unit: CharLengthUnits::Characters, } ); + #[cfg(feature = "decimal")] assert_eq!( - Float64ToDecimalCastEvaluator { - scale: Some(1), - to: LogicalType::Decimal(None, Some(1)), - } - .eval_cast(&value) - .unwrap(), + float64_to_decimal_cast_eval(None, Some(1), &value).unwrap(), DataValue::Decimal(Decimal::new(15, 1)) ); } diff --git a/src/types/evaluator/int16.rs b/src/types/evaluator/int16.rs index 09759419..207623ab 100644 --- a/src/types/evaluator/int16.rs +++ b/src/types/evaluator/int16.rs @@ -23,7 +23,6 @@ crate::define_integer_cast_evaluators!(Int16, Int16, i16, LogicalType::Smallint) #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; @@ -34,56 +33,51 @@ mod test { let value = DataValue::Int16(1); assert_eq!( - Int16ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + int16_to_boolean_cast_eval(&value).unwrap(), DataValue::Boolean(true) ); assert_eq!( - Int16ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + int16_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - Int16ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + int16_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - Int16ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + int16_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - Int16ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + int16_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - Int16ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + int16_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - Int16ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + int16_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - Int16ToBigintCastEvaluator.eval_cast(&value).unwrap(), + int16_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - Int16ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + int16_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - Int16ToFloatCastEvaluator.eval_cast(&value).unwrap(), + int16_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - Int16ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + int16_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - Int16ToCharCastEvaluator { - len: 1, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + int16_to_char_cast_eval(1, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Fixed(1), @@ -91,12 +85,7 @@ mod test { } ); assert_eq!( - Int16ToVarcharCastEvaluator { - len: Some(1), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + int16_to_varchar_cast_eval(Some(1), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Variable(Some(1)), @@ -104,9 +93,7 @@ mod test { } ); assert_eq!( - Int16ToDecimalCastEvaluator { scale: Some(2) } - .eval_cast(&value) - .unwrap(), + int16_to_decimal_cast_eval(Some(2), &value).unwrap(), DataValue::Decimal(Decimal::new(100, 2)) ); } diff --git a/src/types/evaluator/int32.rs b/src/types/evaluator/int32.rs index 36471f8b..b8e6e342 100644 --- a/src/types/evaluator/int32.rs +++ b/src/types/evaluator/int32.rs @@ -23,7 +23,6 @@ crate::define_integer_cast_evaluators!(Int32, Int32, i32, LogicalType::Integer); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; @@ -32,27 +31,19 @@ mod test { #[test] fn test_int32_binary_evaluators() { assert_eq!( - Int32PlusBinaryEvaluator - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)) - .unwrap(), + int32_plus_binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)).unwrap(), DataValue::Int32(2) ); assert_eq!( - Int32MinusBinaryEvaluator - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)) - .unwrap(), + int32_minus_binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)).unwrap(), DataValue::Int32(0) ); assert_eq!( - Int32EqBinaryEvaluator - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)) - .unwrap(), + int32_eq_binary_eval(&DataValue::Int32(1), &DataValue::Int32(1)).unwrap(), DataValue::Boolean(true) ); assert_eq!( - Int32GtBinaryEvaluator - .binary_eval(&DataValue::Int32(1), &DataValue::Int32(0)) - .unwrap(), + int32_gt_binary_eval(&DataValue::Int32(1), &DataValue::Int32(0)).unwrap(), DataValue::Boolean(true) ); } @@ -62,56 +53,51 @@ mod test { let value = DataValue::Int32(1); assert_eq!( - Int32ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + int32_to_boolean_cast_eval(&value).unwrap(), DataValue::Boolean(true) ); assert_eq!( - Int32ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + int32_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - Int32ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + int32_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - Int32ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + int32_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - Int32ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + int32_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - Int32ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + int32_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - Int32ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + int32_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - Int32ToBigintCastEvaluator.eval_cast(&value).unwrap(), + int32_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - Int32ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + int32_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - Int32ToFloatCastEvaluator.eval_cast(&value).unwrap(), + int32_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - Int32ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + int32_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - Int32ToCharCastEvaluator { - len: 1, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + int32_to_char_cast_eval(1, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Fixed(1), @@ -119,12 +105,7 @@ mod test { } ); assert_eq!( - Int32ToVarcharCastEvaluator { - len: Some(1), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + int32_to_varchar_cast_eval(Some(1), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Variable(Some(1)), @@ -132,9 +113,7 @@ mod test { } ); assert_eq!( - Int32ToDecimalCastEvaluator { scale: Some(1) } - .eval_cast(&value) - .unwrap(), + int32_to_decimal_cast_eval(Some(1), &value).unwrap(), DataValue::Decimal(Decimal::new(10, 1)) ); } diff --git a/src/types/evaluator/int64.rs b/src/types/evaluator/int64.rs index 3aace64b..f213d30c 100644 --- a/src/types/evaluator/int64.rs +++ b/src/types/evaluator/int64.rs @@ -23,7 +23,6 @@ crate::define_integer_cast_evaluators!(Int64, Int64, i64, LogicalType::Bigint); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; @@ -34,56 +33,51 @@ mod test { let value = DataValue::Int64(1); assert_eq!( - Int64ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + int64_to_boolean_cast_eval(&value).unwrap(), DataValue::Boolean(true) ); assert_eq!( - Int64ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + int64_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - Int64ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + int64_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - Int64ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + int64_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - Int64ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + int64_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - Int64ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + int64_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - Int64ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + int64_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - Int64ToBigintCastEvaluator.eval_cast(&value).unwrap(), + int64_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - Int64ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + int64_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - Int64ToFloatCastEvaluator.eval_cast(&value).unwrap(), + int64_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - Int64ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + int64_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - Int64ToCharCastEvaluator { - len: 1, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + int64_to_char_cast_eval(1, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Fixed(1), @@ -91,12 +85,7 @@ mod test { } ); assert_eq!( - Int64ToVarcharCastEvaluator { - len: Some(1), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + int64_to_varchar_cast_eval(Some(1), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Variable(Some(1)), @@ -104,9 +93,7 @@ mod test { } ); assert_eq!( - Int64ToDecimalCastEvaluator { scale: Some(2) } - .eval_cast(&value) - .unwrap(), + int64_to_decimal_cast_eval(Some(2), &value).unwrap(), DataValue::Decimal(Decimal::new(100, 2)) ); } diff --git a/src/types/evaluator/int8.rs b/src/types/evaluator/int8.rs index adb29aa9..78915bff 100644 --- a/src/types/evaluator/int8.rs +++ b/src/types/evaluator/int8.rs @@ -23,7 +23,6 @@ crate::define_integer_cast_evaluators!(Int8, Int8, i8, LogicalType::Tinyint); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; @@ -34,56 +33,51 @@ mod test { let value = DataValue::Int8(1); assert_eq!( - Int8ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + int8_to_boolean_cast_eval(&value).unwrap(), DataValue::Boolean(true) ); assert_eq!( - Int8ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + int8_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - Int8ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + int8_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - Int8ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + int8_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - Int8ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + int8_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - Int8ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + int8_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - Int8ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + int8_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - Int8ToBigintCastEvaluator.eval_cast(&value).unwrap(), + int8_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - Int8ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + int8_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - Int8ToFloatCastEvaluator.eval_cast(&value).unwrap(), + int8_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - Int8ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + int8_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - Int8ToCharCastEvaluator { - len: 1, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + int8_to_char_cast_eval(1, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Fixed(1), @@ -91,12 +85,7 @@ mod test { } ); assert_eq!( - Int8ToVarcharCastEvaluator { - len: Some(1), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + int8_to_varchar_cast_eval(Some(1), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Variable(Some(1)), @@ -104,9 +93,7 @@ mod test { } ); assert_eq!( - Int8ToDecimalCastEvaluator { scale: Some(2) } - .eval_cast(&value) - .unwrap(), + int8_to_decimal_cast_eval(Some(2), &value).unwrap(), DataValue::Decimal(Decimal::new(100, 2)) ); } diff --git a/src/types/evaluator/mod.rs b/src/types/evaluator/mod.rs index 1583df11..e37be249 100644 --- a/src/types/evaluator/mod.rs +++ b/src/types/evaluator/mod.rs @@ -15,8 +15,11 @@ pub mod binary; pub mod boolean; pub mod cast; +#[cfg(feature = "time")] pub mod date; +#[cfg(feature = "time")] pub mod datetime; +#[cfg(feature = "decimal")] pub mod decimal; pub mod float32; pub mod float64; @@ -25,7 +28,9 @@ pub mod int32; pub mod int64; pub mod int8; pub mod null; +#[cfg(feature = "time")] pub mod time32; +#[cfg(feature = "time")] pub mod time64; pub mod tuple; pub mod uint16; @@ -40,44 +45,24 @@ pub use self::cast::cast_create; pub use self::unary::unary_create; use crate::errors::DatabaseError; -use crate::expression::{BinaryOperator, UnaryOperator}; use crate::types::value::DataValue; -use crate::types::LogicalType; -use std::fmt::Debug; -use std::hash::{Hash, Hasher}; -use std::ops::Deref; -use std::sync::Arc; +use crate::types::CharLengthUnits; -pub trait BinaryEvaluator: Send + Sync + Debug { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result; +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum BinaryEvaluatorParams { + Unit, + Like { escape_char: Option }, } -pub trait UnaryEvaluator: Send + Sync + Debug { - fn unary_eval(&self, value: &DataValue) -> DataValue; +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct BinaryEvaluatorRef { + pub pos: u16, + pub params: BinaryEvaluatorParams, } -pub trait CastEvaluator: Send + Sync + Debug { - fn eval_cast(&self, value: &DataValue) -> Result; -} - -#[derive(Clone, Debug)] -pub struct BinaryEvaluatorBox { - pub evaluator: Arc, - pub ty: LogicalType, - pub op: BinaryOperator, -} - -impl Deref for BinaryEvaluatorBox { - type Target = dyn BinaryEvaluator; - - fn deref(&self) -> &Self::Target { - self.evaluator.as_ref() - } -} - -impl BinaryEvaluatorBox { - pub fn new(evaluator: Arc, ty: LogicalType, op: BinaryOperator) -> Self { - Self { evaluator, ty, op } +impl BinaryEvaluatorRef { + pub fn new(pos: u16, params: BinaryEvaluatorParams) -> Self { + Self { pos, params } } pub fn binary_eval( @@ -85,97 +70,62 @@ impl BinaryEvaluatorBox { left: &DataValue, right: &DataValue, ) -> Result { - self.evaluator.binary_eval(left, right) + binary::eval_binary(self.pos, &self.params, left, right) } } -#[derive(Clone, Debug)] -pub struct UnaryEvaluatorBox { - pub evaluator: Arc, - pub ty: LogicalType, - pub op: UnaryOperator, +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct UnaryEvaluatorRef { + pub pos: u16, } -impl UnaryEvaluatorBox { - pub fn new(evaluator: Arc, ty: LogicalType, op: UnaryOperator) -> Self { - Self { evaluator, ty, op } +impl UnaryEvaluatorRef { + pub fn new(pos: u16) -> Self { + Self { pos } } pub fn unary_eval(&self, value: &DataValue) -> DataValue { - self.evaluator.unary_eval(value) - } -} - -#[derive(Clone, Debug)] -pub struct CastEvaluatorBox { - pub evaluator: Arc, - pub from: LogicalType, - pub to: LogicalType, -} - -impl Deref for CastEvaluatorBox { - type Target = dyn CastEvaluator; - - fn deref(&self) -> &Self::Target { - self.evaluator.as_ref() + unary::eval_unary(self.pos, value) } } -impl CastEvaluatorBox { - pub fn new(evaluator: Arc, from: LogicalType, to: LogicalType) -> Self { - Self { - evaluator, - from, - to, - } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct CastEvaluatorRef { + pub pos: u16, + pub params: CastEvaluatorParams, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum CastEvaluatorParams { + Identity, + Unit, + String { + len: Option, + unit: CharLengthUnits, + }, + #[cfg(feature = "decimal")] + Decimal { + precision: Option, + scale: Option, + }, + Precision { + precision: Option, + }, + Timestamp { + precision: Option, + zone: bool, + }, + Tuple { + evaluators: Vec, + }, +} + +impl CastEvaluatorRef { + pub fn new(pos: u16, params: CastEvaluatorParams) -> Self { + Self { pos, params } } - pub fn eval_cast(&self, value: &DataValue) -> Result { - self.evaluator.eval_cast(value) - } -} - -impl PartialEq for BinaryEvaluatorBox { - fn eq(&self, _: &Self) -> bool { - // FIXME - true - } -} - -impl Eq for BinaryEvaluatorBox {} - -impl Hash for BinaryEvaluatorBox { - fn hash(&self, state: &mut H) { - state.write_i8(42) - } -} - -impl PartialEq for UnaryEvaluatorBox { - fn eq(&self, _: &Self) -> bool { - // FIXME - true - } -} - -impl Eq for UnaryEvaluatorBox {} - -impl Hash for UnaryEvaluatorBox { - fn hash(&self, state: &mut H) { - state.write_i8(42) - } -} - -impl PartialEq for CastEvaluatorBox { - fn eq(&self, _: &Self) -> bool { - // FIXME - true - } -} - -impl Eq for CastEvaluatorBox {} - -impl Hash for CastEvaluatorBox { - fn hash(&self, state: &mut H) { - state.write_i8(42) + pub fn pos(&self) -> u16 { + self.pos } } diff --git a/src/types/evaluator/null.rs b/src/types/evaluator/null.rs index e53918c9..c92a7a42 100644 --- a/src/types/evaluator/null.rs +++ b/src/types/evaluator/null.rs @@ -14,32 +14,17 @@ use crate::errors::DatabaseError; use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; /// Tips: /// - Null values operate as null values -#[derive(Debug)] -pub struct NullBinaryEvaluator; -impl BinaryEvaluator for NullBinaryEvaluator { - fn binary_eval(&self, _: &DataValue, _: &DataValue) -> Result { - Ok(DataValue::Null) - } +pub fn null_binary_eval(_: &DataValue, _: &DataValue) -> Result { + Ok(DataValue::Null) } - -#[derive(Debug)] -pub struct ToSqlNullCastEvaluator; -impl CastEvaluator for ToSqlNullCastEvaluator { - fn eval_cast(&self, _value: &DataValue) -> Result { - Ok(DataValue::Null) - } +pub fn to_sql_null_cast_eval(_value: &DataValue) -> Result { + Ok(DataValue::Null) } - -#[derive(Debug)] -pub struct NullCastEvaluator; -impl CastEvaluator for NullCastEvaluator { - fn eval_cast(&self, _value: &DataValue) -> Result { - Ok(DataValue::Null) - } +pub fn null_cast_eval(_value: &DataValue) -> Result { + Ok(DataValue::Null) } #[cfg(all(test, not(target_arch = "wasm32")))] @@ -49,14 +34,9 @@ mod test { #[test] fn test_null_cast_evaluators() { assert_eq!( - ToSqlNullCastEvaluator - .eval_cast(&DataValue::Int32(1)) - .unwrap(), - DataValue::Null - ); - assert_eq!( - NullCastEvaluator.eval_cast(&DataValue::Null).unwrap(), + to_sql_null_cast_eval(&DataValue::Int32(1)).unwrap(), DataValue::Null ); + assert_eq!(null_cast_eval(&DataValue::Null).unwrap(), DataValue::Null); } } diff --git a/src/types/evaluator/time32.rs b/src/types/evaluator/time32.rs index 18866ffc..3375c8e1 100644 --- a/src/types/evaluator/time32.rs +++ b/src/types/evaluator/time32.rs @@ -14,117 +14,94 @@ use crate::errors::DatabaseError; use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; use crate::types::value::{ONE_DAY_TO_SEC, ONE_SEC_TO_NANO}; use crate::types::CharLengthUnits; use crate::types::LogicalType; use std::hint; - -#[derive(Debug)] -pub struct TimePlusBinaryEvaluator; -#[derive(Debug)] -pub struct TimeMinusBinaryEvaluator; -#[derive(Debug)] -pub struct TimeGtBinaryEvaluator; -#[derive(Debug)] -pub struct TimeGtEqBinaryEvaluator; -#[derive(Debug)] -pub struct TimeLtBinaryEvaluator; -#[derive(Debug)] -pub struct TimeLtEqBinaryEvaluator; -#[derive(Debug)] -pub struct TimeEqBinaryEvaluator; -#[derive(Debug)] -pub struct TimeNotEqBinaryEvaluator; -impl BinaryEvaluator for TimePlusBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2)) => { - let (mut v1, n1) = DataValue::unpack(*v1, *p1); - let (v2, n2) = DataValue::unpack(*v2, *p2); - let mut n = n1 + n2; - while n > ONE_SEC_TO_NANO { - v1 += 1; - n -= ONE_SEC_TO_NANO; - } - let p = if p2 > p1 { *p2 } else { *p1 }; - if v1 + v2 > ONE_DAY_TO_SEC { - return Ok(DataValue::Null); - } - DataValue::Time32(DataValue::pack(v1 + v2, n, p), p) +pub fn time_plus_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2)) => { + let (mut v1, n1) = DataValue::unpack(*v1, *p1); + let (v2, n2) = DataValue::unpack(*v2, *p2); + let mut n = n1 + n2; + while n > ONE_SEC_TO_NANO { + v1 += 1; + n -= ONE_SEC_TO_NANO; } - (DataValue::Time32(..), DataValue::Null) - | (DataValue::Null, DataValue::Time32(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for TimeMinusBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { - let (mut v1, mut n1) = DataValue::unpack(*v1, *p1); - let (v2, n2) = DataValue::unpack(*v2, *p2); - while n1 < n2 { - v1 -= 1; - n1 += ONE_SEC_TO_NANO; - } - if v1 < v2 { - return Ok(DataValue::Null); - } - let p = if p2 > p1 { *p2 } else { *p1 }; - DataValue::Time32(DataValue::pack(v1 - v2, n1 - n2, p), p) + let p = if p2 > p1 { *p2 } else { *p1 }; + if v1 + v2 > ONE_DAY_TO_SEC { + return Ok(DataValue::Null); } - (DataValue::Time32(..), DataValue::Null) - | (DataValue::Null, DataValue::Time32(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } + DataValue::Time32(DataValue::pack(v1 + v2, n, p), p) + } + (DataValue::Time32(..), DataValue::Null) + | (DataValue::Null, DataValue::Time32(..)) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } -impl BinaryEvaluator for TimeGtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { - let (v1, n1) = DataValue::unpack(*v1, *p1); - let (v2, n2) = DataValue::unpack(*v2, *p2); - DataValue::Boolean(v1.cmp(&v2).then_with(|| n1.cmp(&n2)).is_gt()) +pub fn time_minus_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { + let (mut v1, mut n1) = DataValue::unpack(*v1, *p1); + let (v2, n2) = DataValue::unpack(*v2, *p2); + while n1 < n2 { + v1 -= 1; + n1 += ONE_SEC_TO_NANO; } - (DataValue::Time32(..), DataValue::Null) - | (DataValue::Null, DataValue::Time32(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for TimeGtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { - let (v1, n1) = DataValue::unpack(*v1, *p1); - let (v2, n2) = DataValue::unpack(*v2, *p2); - DataValue::Boolean(!v1.cmp(&v2).then_with(|| n1.cmp(&n2)).is_lt()) + if v1 < v2 { + return Ok(DataValue::Null); } - (DataValue::Time32(..), DataValue::Null) - | (DataValue::Null, DataValue::Time32(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } + let p = if p2 > p1 { *p2 } else { *p1 }; + DataValue::Time32(DataValue::pack(v1 - v2, n1 - n2, p), p) + } + (DataValue::Time32(..), DataValue::Null) + | (DataValue::Null, DataValue::Time32(..)) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } +macro_rules! time32_order_binary { + ($name:ident, $is_order:ident) => { + pub fn $name(left: &DataValue, right: &DataValue) -> Result { + Ok(match (left, right) { + (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { + let (v1, n1) = DataValue::unpack(*v1, *p1); + let (v2, n2) = DataValue::unpack(*v2, *p2); + DataValue::Boolean(v1.cmp(&v2).then_with(|| n1.cmp(&n2)).$is_order()) + } + (DataValue::Time32(..), DataValue::Null) + | (DataValue::Null, DataValue::Time32(..)) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) + } + }; +} + +time32_order_binary!(time_gt_binary_eval, is_gt); +time32_order_binary!(time_gt_eq_binary_eval, is_ge); + crate::define_cast_evaluator!( - Time32ToCharCastEvaluator { + time32_to_char_cast_eval { len: u32, - unit: CharLengthUnits, - to: LogicalType + unit: CharLengthUnits }, DataValue::Time32(value, precision) => |this| { to_char( DataValue::format_time(*value, *precision).ok_or_else(|| { - cast_fail(LogicalType::Time(Some(*precision)), this.to.clone()) + cast_fail( + LogicalType::Time(Some(*precision)), + LogicalType::Char(this.len, this.unit), + ) })?, this.len, this.unit, @@ -132,15 +109,17 @@ crate::define_cast_evaluator!( } ); crate::define_cast_evaluator!( - Time32ToVarcharCastEvaluator { + time32_to_varchar_cast_eval { len: Option, - unit: CharLengthUnits, - to: LogicalType + unit: CharLengthUnits }, DataValue::Time32(value, precision) => |this| { to_varchar( DataValue::format_time(*value, *precision).ok_or_else(|| { - cast_fail(LogicalType::Time(Some(*precision)), this.to.clone()) + cast_fail( + LogicalType::Time(Some(*precision)), + LogicalType::Varchar(this.len, this.unit), + ) })?, this.len, this.unit, @@ -148,98 +127,47 @@ crate::define_cast_evaluator!( } ); crate::define_cast_evaluator!( - Time32ToTimeCastEvaluator { + time32_to_time_cast_eval { precision: Option }, DataValue::Time32(value, _precision) => |this| { Ok(DataValue::Time32(*value, this.precision.unwrap_or(0))) } ); -impl BinaryEvaluator for TimeLtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { - let (v1, n1) = DataValue::unpack(*v1, *p1); - let (v2, n2) = DataValue::unpack(*v2, *p2); - DataValue::Boolean(v1.cmp(&v2).then_with(|| n1.cmp(&n2)).is_lt()) - } - (DataValue::Time32(..), DataValue::Null) - | (DataValue::Null, DataValue::Time32(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for TimeLtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { - let (v1, n1) = DataValue::unpack(*v1, *p1); - let (v2, n2) = DataValue::unpack(*v2, *p2); - DataValue::Boolean(!v1.cmp(&v2).then_with(|| n1.cmp(&n2)).is_gt()) - } - (DataValue::Time32(..), DataValue::Null) - | (DataValue::Null, DataValue::Time32(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for TimeEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { - let (v1, n1) = DataValue::unpack(*v1, *p1); - let (v2, n2) = DataValue::unpack(*v2, *p2); - DataValue::Boolean(v1.cmp(&v2).then_with(|| n1.cmp(&n2)).is_eq()) - } - (DataValue::Time32(..), DataValue::Null) - | (DataValue::Null, DataValue::Time32(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for TimeNotEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time32(v1, p1), DataValue::Time32(v2, p2, ..)) => { - let (v1, n1) = DataValue::unpack(*v1, *p1); - let (v2, n2) = DataValue::unpack(*v2, *p2); - DataValue::Boolean(!v1.cmp(&v2).then_with(|| n1.cmp(&n2)).is_eq()) - } - (DataValue::Time32(..), DataValue::Null) - | (DataValue::Null, DataValue::Time32(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +time32_order_binary!(time_lt_binary_eval, is_lt); +time32_order_binary!(time_lt_eq_binary_eval, is_le); +time32_order_binary!(time_eq_binary_eval, is_eq); +pub fn time_not_eq_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match time_eq_binary_eval(left, right)? { + DataValue::Boolean(value) => DataValue::Boolean(!value), + value => value, + }) } #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use crate::types::value::Utf8Type; #[test] fn test_time32_binary_evaluators() { assert_eq!( - TimePlusBinaryEvaluator - .binary_eval( - &DataValue::Time32(4_190_119_896, 3), - &DataValue::Time32(2_621_204_256, 4), - ) - .unwrap(), + time_plus_binary_eval( + &DataValue::Time32(4_190_119_896, 3), + &DataValue::Time32(2_621_204_256, 4), + ) + .unwrap(), DataValue::Time32(2_618_593_017, 4) ); assert_eq!( - TimeGtBinaryEvaluator - .binary_eval( - &DataValue::Time32(2_621_204_256, 4), - &DataValue::Time32(4_190_119_896, 3), - ) - .unwrap(), + time_gt_binary_eval( + &DataValue::Time32(2_621_204_256, 4), + &DataValue::Time32(4_190_119_896, 3), + ) + .unwrap(), DataValue::Boolean(true) ); } @@ -248,13 +176,7 @@ mod test { fn test_time32_cast_evaluators() { let value = DataValue::Time32(DataValue::pack(3 * 3600 + 4 * 60 + 5, 123_000_000, 3), 3); assert_eq!( - Time32ToCharCastEvaluator { - len: 12, - unit: CharLengthUnits::Characters, - to: LogicalType::Char(12, CharLengthUnits::Characters), - } - .eval_cast(&value) - .unwrap(), + time32_to_char_cast_eval(12, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "03:04:05.123".to_string(), ty: Utf8Type::Fixed(12), @@ -262,24 +184,13 @@ mod test { } ); assert_eq!( - Time32ToVarcharCastEvaluator { - len: Some(12), - unit: CharLengthUnits::Characters, - to: LogicalType::Varchar(Some(12), CharLengthUnits::Characters), - } - .eval_cast(&value) - .unwrap(), + time32_to_varchar_cast_eval(Some(12), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "03:04:05.123".to_string(), ty: Utf8Type::Variable(Some(12)), unit: CharLengthUnits::Characters, } ); - assert_eq!( - Time32ToTimeCastEvaluator { precision: Some(3) } - .eval_cast(&value) - .unwrap(), - value - ); + assert_eq!(time32_to_time_cast_eval(Some(3), &value).unwrap(), value); } } diff --git a/src/types/evaluator/time64.rs b/src/types/evaluator/time64.rs index 3468f40c..cd9e2f7c 100644 --- a/src/types/evaluator/time64.rs +++ b/src/types/evaluator/time64.rs @@ -14,84 +14,54 @@ use crate::errors::DatabaseError; use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; -use crate::types::evaluator::BinaryEvaluator; use crate::types::evaluator::DataValue; use crate::types::CharLengthUnits; use crate::types::LogicalType; use chrono::{Datelike, Timelike}; use std::hint; -#[derive(Debug)] -pub struct Time64GtBinaryEvaluator; -#[derive(Debug)] -pub struct Time64GtEqBinaryEvaluator; -#[derive(Debug)] -pub struct Time64LtBinaryEvaluator; -#[derive(Debug)] -pub struct Time64LtEqBinaryEvaluator; -#[derive(Debug)] -pub struct Time64EqBinaryEvaluator; -#[derive(Debug)] -pub struct Time64NotEqBinaryEvaluator; -impl BinaryEvaluator for Time64GtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time64(v1, p1, _), DataValue::Time64(v2, p2, _)) => { - if let (Some(v1), Some(v2)) = ( - DataValue::from_timestamp_precision(*v1, *p1), - DataValue::from_timestamp_precision(*v2, *p2), - ) { - let p = if p2 > p1 { *p2 } else { *p1 }; - DataValue::Boolean( - DataValue::timestamp_precision(v1, p) - > DataValue::timestamp_precision(v2, p), - ) - } else { - DataValue::Null +macro_rules! time64_binary { + ($name:ident, $op:tt) => { + pub fn $name(left: &DataValue, right: &DataValue) -> Result { + Ok(match (left, right) { + (DataValue::Time64(v1, p1, _), DataValue::Time64(v2, p2, _)) => { + if let (Some(v1), Some(v2)) = ( + DataValue::from_timestamp_precision(*v1, *p1), + DataValue::from_timestamp_precision(*v2, *p2), + ) { + let p = if p2 > p1 { *p2 } else { *p1 }; + DataValue::Boolean( + DataValue::timestamp_precision(v1, p) + $op DataValue::timestamp_precision(v2, p), + ) + } else { + DataValue::Null + } } - } - (DataValue::Time64(..), DataValue::Null) - | (DataValue::Null, DataValue::Time64(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Time64GtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time64(v1, p1, _), DataValue::Time64(v2, p2, _)) => { - if let (Some(v1), Some(v2)) = ( - DataValue::from_timestamp_precision(*v1, *p1), - DataValue::from_timestamp_precision(*v2, *p2), - ) { - let p = if p2 > p1 { *p2 } else { *p1 }; - DataValue::Boolean( - DataValue::timestamp_precision(v1, p) - >= DataValue::timestamp_precision(v2, p), - ) - } else { - DataValue::Null - } - } - (DataValue::Time64(..), DataValue::Null) - | (DataValue::Null, DataValue::Time64(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } + (DataValue::Time64(..), DataValue::Null) + | (DataValue::Null, DataValue::Time64(..)) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) + } + }; } +time64_binary!(time64_gt_binary_eval, >); +time64_binary!(time64_gt_eq_binary_eval, >=); + crate::define_cast_evaluator!( - Time64ToCharCastEvaluator { + time64_to_char_cast_eval { len: u32, - unit: CharLengthUnits, - to: LogicalType + unit: CharLengthUnits }, - DataValue::Time64(value, precision, _) => |this| { + DataValue::Time64(value, precision, zone) => |this| { to_char( DataValue::format_timestamp(*value, *precision).ok_or_else(|| { - cast_fail(LogicalType::TimeStamp(Some(*precision), false), this.to.clone()) + cast_fail( + LogicalType::TimeStamp(Some(*precision), *zone), + LogicalType::Char(this.len, this.unit), + ) })?, this.len, this.unit, @@ -99,15 +69,17 @@ crate::define_cast_evaluator!( } ); crate::define_cast_evaluator!( - Time64ToVarcharCastEvaluator { + time64_to_varchar_cast_eval { len: Option, - unit: CharLengthUnits, - to: LogicalType + unit: CharLengthUnits }, - DataValue::Time64(value, precision, _) => |this| { + DataValue::Time64(value, precision, zone) => |this| { to_varchar( DataValue::format_timestamp(*value, *precision).ok_or_else(|| { - cast_fail(LogicalType::TimeStamp(Some(*precision), false), this.to.clone()) + cast_fail( + LogicalType::TimeStamp(Some(*precision), *zone), + LogicalType::Varchar(this.len, this.unit), + ) })?, this.len, this.unit, @@ -115,13 +87,15 @@ crate::define_cast_evaluator!( } ); crate::define_cast_evaluator!( - Time64ToDateCastEvaluator { - from: LogicalType, - to: LogicalType - }, - DataValue::Time64(value, precision, _) => |this| { + time64_to_date_cast_eval, + DataValue::Time64(value, precision, zone) => { let value = DataValue::from_timestamp_precision(*value, *precision) - .ok_or_else(|| cast_fail(this.from.clone(), this.to.clone()))? + .ok_or_else(|| { + cast_fail( + LogicalType::TimeStamp(Some(*precision), *zone), + LogicalType::Date, + ) + })? .naive_utc() .date() .num_days_from_ce(); @@ -130,25 +104,25 @@ crate::define_cast_evaluator!( } ); crate::define_cast_evaluator!( - Time64ToDatetimeCastEvaluator { - from: LogicalType, - to: LogicalType - }, - DataValue::Time64(value, precision, _) => |this| { + time64_to_datetime_cast_eval, + DataValue::Time64(value, precision, zone) => { let value = DataValue::from_timestamp_precision(*value, *precision) - .ok_or_else(|| cast_fail(this.from.clone(), this.to.clone()))? + .ok_or_else(|| { + cast_fail( + LogicalType::TimeStamp(Some(*precision), *zone), + LogicalType::DateTime, + ) + })? .timestamp(); Ok(DataValue::Date64(value)) } ); crate::define_cast_evaluator!( - Time64ToTimeCastEvaluator { - precision: Option, - from: LogicalType, - to: LogicalType + time64_to_time_cast_eval { + precision: Option }, - DataValue::Time64(value, precision, _) => |this| { + DataValue::Time64(value, precision, zone) => |this| { let target_precision = this.precision.unwrap_or(0); let (value, nano) = DataValue::from_timestamp_precision(*value, *precision) .map(|date_time| { @@ -157,7 +131,12 @@ crate::define_cast_evaluator!( date_time.time().nanosecond(), ) }) - .ok_or_else(|| cast_fail(this.from.clone(), this.to.clone()))?; + .ok_or_else(|| { + cast_fail( + LogicalType::TimeStamp(Some(*precision), *zone), + LogicalType::Time(this.precision), + ) + })?; Ok(DataValue::Time32( DataValue::pack(value, nano, target_precision), @@ -166,7 +145,7 @@ crate::define_cast_evaluator!( } ); crate::define_cast_evaluator!( - Time64ToTimestampCastEvaluator { + time64_to_timestamp_cast_eval { precision: Option, zone: bool }, @@ -174,119 +153,25 @@ crate::define_cast_evaluator!( Ok(DataValue::Time64(*value, this.precision.unwrap_or(0), this.zone)) } ); -impl BinaryEvaluator for Time64LtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time64(v1, p1, _), DataValue::Time64(v2, p2, _)) => { - if let (Some(v1), Some(v2)) = ( - DataValue::from_timestamp_precision(*v1, *p1), - DataValue::from_timestamp_precision(*v2, *p2), - ) { - let p = if p2 > p1 { *p2 } else { *p1 }; - DataValue::Boolean( - DataValue::timestamp_precision(v1, p) - < DataValue::timestamp_precision(v2, p), - ) - } else { - DataValue::Null - } - } - (DataValue::Time64(..), DataValue::Null) - | (DataValue::Null, DataValue::Time64(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Time64LtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time64(v1, p1, _), DataValue::Time64(v2, p2, _)) => { - if let (Some(v1), Some(v2)) = ( - DataValue::from_timestamp_precision(*v1, *p1), - DataValue::from_timestamp_precision(*v2, *p2), - ) { - let p = if p2 > p1 { *p2 } else { *p1 }; - DataValue::Boolean( - DataValue::timestamp_precision(v1, p) - <= DataValue::timestamp_precision(v2, p), - ) - } else { - DataValue::Null - } - } - (DataValue::Time64(..), DataValue::Null) - | (DataValue::Null, DataValue::Time64(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Time64EqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time64(v1, p1, _), DataValue::Time64(v2, p2, _)) => { - if let (Some(v1), Some(v2)) = ( - DataValue::from_timestamp_precision(*v1, *p1), - DataValue::from_timestamp_precision(*v2, *p2), - ) { - let p = if p2 > p1 { *p2 } else { *p1 }; - DataValue::Boolean( - DataValue::timestamp_precision(v1, p) - == DataValue::timestamp_precision(v2, p), - ) - } else { - DataValue::Null - } - } - (DataValue::Time64(..), DataValue::Null) - | (DataValue::Null, DataValue::Time64(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Time64NotEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Time64(v1, p1, _), DataValue::Time64(v2, p2, _)) => { - if let (Some(v1), Some(v2)) = ( - DataValue::from_timestamp_precision(*v1, *p1), - DataValue::from_timestamp_precision(*v2, *p2), - ) { - let p = if p2 > p1 { *p2 } else { *p1 }; - DataValue::Boolean( - DataValue::timestamp_precision(v1, p) - != DataValue::timestamp_precision(v2, p), - ) - } else { - DataValue::Null - } - } - (DataValue::Time64(..), DataValue::Null) - | (DataValue::Null, DataValue::Time64(..)) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} +time64_binary!(time64_lt_binary_eval, <); +time64_binary!(time64_lt_eq_binary_eval, <=); +time64_binary!(time64_eq_binary_eval, ==); +time64_binary!(time64_not_eq_binary_eval, !=); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; #[test] fn test_time64_binary_evaluators() { assert_eq!( - Time64EqBinaryEvaluator - .binary_eval( - &DataValue::Time64(1_738_734_177_256, 3, false), - &DataValue::Time64(1_738_734_177_256_000, 6, false), - ) - .unwrap(), + time64_eq_binary_eval( + &DataValue::Time64(1_738_734_177_256, 3, false), + &DataValue::Time64(1_738_734_177_256_000, 6, false), + ) + .unwrap(), DataValue::Boolean(true) ); } @@ -301,13 +186,7 @@ mod test { .timestamp_millis(); let value = DataValue::Time64(timestamp, 3, false); assert_eq!( - Time64ToCharCastEvaluator { - len: 23, - unit: CharLengthUnits::Characters, - to: LogicalType::Char(23, CharLengthUnits::Characters), - } - .eval_cast(&value) - .unwrap(), + time64_to_char_cast_eval(23, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "2024-01-02 03:04:05.123".to_string(), ty: Utf8Type::Fixed(23), @@ -315,13 +194,7 @@ mod test { } ); assert_eq!( - Time64ToVarcharCastEvaluator { - len: Some(23), - unit: CharLengthUnits::Characters, - to: LogicalType::Varchar(Some(23), CharLengthUnits::Characters), - } - .eval_cast(&value) - .unwrap(), + time64_to_varchar_cast_eval(Some(23), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "2024-01-02 03:04:05.123".to_string(), ty: Utf8Type::Variable(Some(23)), @@ -329,12 +202,7 @@ mod test { } ); assert_eq!( - Time64ToDateCastEvaluator { - from: LogicalType::TimeStamp(Some(3), false), - to: LogicalType::Date, - } - .eval_cast(&value) - .unwrap(), + time64_to_date_cast_eval(&value).unwrap(), DataValue::Date32( chrono::NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() @@ -342,12 +210,7 @@ mod test { ) ); assert_eq!( - Time64ToDatetimeCastEvaluator { - from: LogicalType::TimeStamp(Some(3), false), - to: LogicalType::DateTime, - } - .eval_cast(&value) - .unwrap(), + time64_to_datetime_cast_eval(&value).unwrap(), DataValue::Date64( chrono::NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() @@ -358,22 +221,11 @@ mod test { ) ); assert_eq!( - Time64ToTimeCastEvaluator { - precision: Some(3), - from: LogicalType::TimeStamp(Some(3), false), - to: LogicalType::Time(Some(3)), - } - .eval_cast(&value) - .unwrap(), + time64_to_time_cast_eval(Some(3), &value).unwrap(), DataValue::Time32(DataValue::pack(3 * 3600 + 4 * 60 + 5, 123_000_000, 3), 3) ); assert_eq!( - Time64ToTimestampCastEvaluator { - precision: Some(3), - zone: true, - } - .eval_cast(&value) - .unwrap(), + time64_to_timestamp_cast_eval(Some(3), true, &value).unwrap(), DataValue::Time64(timestamp, 3, true) ); } diff --git a/src/types/evaluator/tuple.rs b/src/types/evaluator/tuple.rs index 4dd843be..d63b7011 100644 --- a/src/types/evaluator/tuple.rs +++ b/src/types/evaluator/tuple.rs @@ -13,24 +13,11 @@ // limitations under the License. use crate::errors::DatabaseError; +use crate::types::evaluator::CastEvaluatorRef; use crate::types::evaluator::DataValue; -use crate::types::evaluator::{BinaryEvaluator, CastEvaluator, CastEvaluatorBox}; use std::cmp::Ordering; use std::hint; -#[derive(Debug)] -pub struct TupleEqBinaryEvaluator; -#[derive(Debug)] -pub struct TupleNotEqBinaryEvaluator; -#[derive(Debug)] -pub struct TupleGtBinaryEvaluator; -#[derive(Debug)] -pub struct TupleGtEqBinaryEvaluator; -#[derive(Debug)] -pub struct TupleLtBinaryEvaluator; -#[derive(Debug)] -pub struct TupleLtEqBinaryEvaluator; - fn tuple_cmp( (v1, v1_is_upper): (&Vec, &bool), (v2, v2_is_upper): (&Vec, &bool), @@ -61,112 +48,72 @@ fn tuple_cmp( } Some(order) } -impl BinaryEvaluator for TupleEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Tuple(v1, ..), DataValue::Tuple(v2, ..)) => DataValue::Boolean(*v1 == *v2), - (DataValue::Null, DataValue::Boolean(_)) - | (DataValue::Boolean(_), DataValue::Null) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +pub fn tuple_eq_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Tuple(v1, ..), DataValue::Tuple(v2, ..)) => DataValue::Boolean(*v1 == *v2), + (DataValue::Null, DataValue::Boolean(_)) + | (DataValue::Boolean(_), DataValue::Null) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } -impl BinaryEvaluator for TupleNotEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Tuple(v1, ..), DataValue::Tuple(v2, ..)) => DataValue::Boolean(*v1 != *v2), - (DataValue::Null, DataValue::Boolean(_)) - | (DataValue::Boolean(_), DataValue::Null) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +pub fn tuple_not_eq_binary_eval( + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Tuple(v1, ..), DataValue::Tuple(v2, ..)) => DataValue::Boolean(*v1 != *v2), + (DataValue::Null, DataValue::Boolean(_)) + | (DataValue::Boolean(_), DataValue::Null) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } -#[derive(Debug)] -pub struct TupleCastEvaluator { - pub element_evaluators: Vec, +macro_rules! tuple_order_binary { + ($name:ident, $is_order:ident) => { + pub fn $name(left: &DataValue, right: &DataValue) -> Result { + Ok(match (left, right) { + (DataValue::Tuple(v1, is_upper1), DataValue::Tuple(v2, is_upper2)) => { + tuple_cmp((v1, is_upper1), (v2, is_upper2)) + .map(|order| DataValue::Boolean(order.$is_order())) + .unwrap_or(DataValue::Null) + } + (DataValue::Null, DataValue::Boolean(_)) + | (DataValue::Boolean(_), DataValue::Null) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) + } + }; } -impl CastEvaluator for TupleCastEvaluator { - fn eval_cast(&self, value: &DataValue) -> Result { - match value { - DataValue::Null => Ok(DataValue::Null), - DataValue::Tuple(values, is_upper) => { - let mut casted = Vec::with_capacity(values.len()); +tuple_order_binary!(tuple_gt_binary_eval, is_gt); +tuple_order_binary!(tuple_gt_eq_binary_eval, is_ge); +tuple_order_binary!(tuple_lt_binary_eval, is_lt); +tuple_order_binary!(tuple_lt_eq_binary_eval, is_le); - for (value, evaluator) in values.iter().zip(self.element_evaluators.iter()) { - casted.push(evaluator.eval_cast(value)?); - } +pub(crate) fn eval_tuple_cast( + element_evaluators: &[CastEvaluatorRef], + value: &DataValue, +) -> Result { + match value { + DataValue::Null => Ok(DataValue::Null), + DataValue::Tuple(values, is_upper) => { + let mut casted = Vec::with_capacity(values.len()); - Ok(DataValue::Tuple(casted, *is_upper)) + for (value, evaluator) in values.iter().zip(element_evaluators.iter()) { + casted.push(evaluator.eval(value)?); } - _ => unsafe { hint::unreachable_unchecked() }, + + Ok(DataValue::Tuple(casted, *is_upper)) } + _ => unsafe { hint::unreachable_unchecked() }, } } -impl BinaryEvaluator for TupleGtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Tuple(v1, is_upper1), DataValue::Tuple(v2, is_upper2)) => { - tuple_cmp((v1, is_upper1), (v2, is_upper2)) - .map(|order| DataValue::Boolean(order.is_gt())) - .unwrap_or(DataValue::Null) - } - (DataValue::Null, DataValue::Boolean(_)) - | (DataValue::Boolean(_), DataValue::Null) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for TupleGtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Tuple(v1, is_upper1), DataValue::Tuple(v2, is_upper2)) => { - tuple_cmp((v1, is_upper1), (v2, is_upper2)) - .map(|order| DataValue::Boolean(order.is_ge())) - .unwrap_or(DataValue::Null) - } - (DataValue::Null, DataValue::Boolean(_)) - | (DataValue::Boolean(_), DataValue::Null) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for TupleLtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Tuple(v1, is_upper1), DataValue::Tuple(v2, is_upper2)) => { - tuple_cmp((v1, is_upper1), (v2, is_upper2)) - .map(|order| DataValue::Boolean(order.is_lt())) - .unwrap_or(DataValue::Null) - } - (DataValue::Null, DataValue::Boolean(_)) - | (DataValue::Boolean(_), DataValue::Null) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for TupleLtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Tuple(v1, is_upper1), DataValue::Tuple(v2, is_upper2)) => { - tuple_cmp((v1, is_upper1), (v2, is_upper2)) - .map(|order| DataValue::Boolean(order.is_le())) - .unwrap_or(DataValue::Null) - } - (DataValue::Null, DataValue::Boolean(_)) - | (DataValue::Boolean(_), DataValue::Null) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} - #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; @@ -176,7 +123,7 @@ mod test { use std::borrow::Cow; #[test] - fn test_tuple_cast_evaluator() { + fn test_tuple_cast_eval() { let evaluator = cast_create( Cow::Owned(LogicalType::Tuple(vec![ LogicalType::Integer, @@ -191,7 +138,7 @@ mod test { assert_eq!( evaluator - .eval_cast(&DataValue::Tuple( + .eval(&DataValue::Tuple( vec![ DataValue::Int32(1), DataValue::Utf8 { diff --git a/src/types/evaluator/uint16.rs b/src/types/evaluator/uint16.rs index 332013e2..237e864f 100644 --- a/src/types/evaluator/uint16.rs +++ b/src/types/evaluator/uint16.rs @@ -16,13 +16,12 @@ use crate::numeric_binary_evaluator_definition; use crate::types::evaluator::DataValue; use crate::types::LogicalType; -numeric_binary_evaluator_definition!(UInt16, DataValue::UInt16); -crate::define_integer_cast_evaluators!(UInt16, UInt16, u16, LogicalType::USmallint); +numeric_binary_evaluator_definition!(Uint16, DataValue::UInt16); +crate::define_integer_cast_evaluators!(Uint16, UInt16, u16, LogicalType::USmallint); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; @@ -33,56 +32,51 @@ mod test { let value = DataValue::UInt16(1); assert_eq!( - UInt16ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_boolean_cast_eval(&value).unwrap(), DataValue::Boolean(true) ); assert_eq!( - UInt16ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - UInt16ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - UInt16ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - UInt16ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - UInt16ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - UInt16ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - UInt16ToBigintCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - UInt16ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - UInt16ToFloatCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - UInt16ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + uint16_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - UInt16ToCharCastEvaluator { - len: 1, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + uint16_to_char_cast_eval(1, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Fixed(1), @@ -90,12 +84,7 @@ mod test { } ); assert_eq!( - UInt16ToVarcharCastEvaluator { - len: Some(1), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + uint16_to_varchar_cast_eval(Some(1), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Variable(Some(1)), @@ -103,9 +92,7 @@ mod test { } ); assert_eq!( - UInt16ToDecimalCastEvaluator { scale: Some(2) } - .eval_cast(&value) - .unwrap(), + uint16_to_decimal_cast_eval(Some(2), &value).unwrap(), DataValue::Decimal(Decimal::new(100, 2)) ); } diff --git a/src/types/evaluator/uint32.rs b/src/types/evaluator/uint32.rs index 9bc08089..d3f04b8a 100644 --- a/src/types/evaluator/uint32.rs +++ b/src/types/evaluator/uint32.rs @@ -16,13 +16,12 @@ use crate::numeric_binary_evaluator_definition; use crate::types::evaluator::DataValue; use crate::types::LogicalType; -numeric_binary_evaluator_definition!(UInt32, DataValue::UInt32); -crate::define_integer_cast_evaluators!(UInt32, UInt32, u32, LogicalType::UInteger); +numeric_binary_evaluator_definition!(Uint32, DataValue::UInt32); +crate::define_integer_cast_evaluators!(Uint32, UInt32, u32, LogicalType::UInteger); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; @@ -33,56 +32,51 @@ mod test { let value = DataValue::UInt32(1); assert_eq!( - UInt32ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_boolean_cast_eval(&value).unwrap(), DataValue::Boolean(true) ); assert_eq!( - UInt32ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - UInt32ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - UInt32ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - UInt32ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - UInt32ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - UInt32ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - UInt32ToBigintCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - UInt32ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - UInt32ToFloatCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - UInt32ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + uint32_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - UInt32ToCharCastEvaluator { - len: 1, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + uint32_to_char_cast_eval(1, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Fixed(1), @@ -90,12 +84,7 @@ mod test { } ); assert_eq!( - UInt32ToVarcharCastEvaluator { - len: Some(1), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + uint32_to_varchar_cast_eval(Some(1), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Variable(Some(1)), @@ -103,9 +92,7 @@ mod test { } ); assert_eq!( - UInt32ToDecimalCastEvaluator { scale: Some(2) } - .eval_cast(&value) - .unwrap(), + uint32_to_decimal_cast_eval(Some(2), &value).unwrap(), DataValue::Decimal(Decimal::new(100, 2)) ); } diff --git a/src/types/evaluator/uint64.rs b/src/types/evaluator/uint64.rs index 5b2a64d3..3f912e46 100644 --- a/src/types/evaluator/uint64.rs +++ b/src/types/evaluator/uint64.rs @@ -16,13 +16,12 @@ use crate::numeric_binary_evaluator_definition; use crate::types::evaluator::DataValue; use crate::types::LogicalType; -numeric_binary_evaluator_definition!(UInt64, DataValue::UInt64); -crate::define_integer_cast_evaluators!(UInt64, UInt64, u64, LogicalType::UBigint); +numeric_binary_evaluator_definition!(Uint64, DataValue::UInt64); +crate::define_integer_cast_evaluators!(Uint64, UInt64, u64, LogicalType::UBigint); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; @@ -33,56 +32,51 @@ mod test { let value = DataValue::UInt64(1); assert_eq!( - UInt64ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_boolean_cast_eval(&value).unwrap(), DataValue::Boolean(true) ); assert_eq!( - UInt64ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - UInt64ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - UInt64ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - UInt64ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - UInt64ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - UInt64ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - UInt64ToBigintCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - UInt64ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - UInt64ToFloatCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - UInt64ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + uint64_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - UInt64ToCharCastEvaluator { - len: 1, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + uint64_to_char_cast_eval(1, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Fixed(1), @@ -90,12 +84,7 @@ mod test { } ); assert_eq!( - UInt64ToVarcharCastEvaluator { - len: Some(1), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + uint64_to_varchar_cast_eval(Some(1), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Variable(Some(1)), @@ -103,9 +92,7 @@ mod test { } ); assert_eq!( - UInt64ToDecimalCastEvaluator { scale: Some(2) } - .eval_cast(&value) - .unwrap(), + uint64_to_decimal_cast_eval(Some(2), &value).unwrap(), DataValue::Decimal(Decimal::new(100, 2)) ); } diff --git a/src/types/evaluator/uint8.rs b/src/types/evaluator/uint8.rs index 355278f9..fda45e21 100644 --- a/src/types/evaluator/uint8.rs +++ b/src/types/evaluator/uint8.rs @@ -16,13 +16,12 @@ use crate::numeric_binary_evaluator_definition; use crate::types::evaluator::DataValue; use crate::types::LogicalType; -numeric_binary_evaluator_definition!(UInt8, DataValue::UInt8); -crate::define_integer_cast_evaluators!(UInt8, UInt8, u8, LogicalType::UTinyint); +numeric_binary_evaluator_definition!(Uint8, DataValue::UInt8); +crate::define_integer_cast_evaluators!(Uint8, UInt8, u8, LogicalType::UTinyint); #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::CastEvaluator; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; use ordered_float::OrderedFloat; @@ -33,56 +32,51 @@ mod test { let value = DataValue::UInt8(1); assert_eq!( - UInt8ToBooleanCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_boolean_cast_eval(&value).unwrap(), DataValue::Boolean(true) ); assert_eq!( - UInt8ToTinyintCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_tinyint_cast_eval(&value).unwrap(), DataValue::Int8(1) ); assert_eq!( - UInt8ToUTinyintCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_utinyint_cast_eval(&value).unwrap(), DataValue::UInt8(1) ); assert_eq!( - UInt8ToSmallintCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_smallint_cast_eval(&value).unwrap(), DataValue::Int16(1) ); assert_eq!( - UInt8ToUSmallintCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_usmallint_cast_eval(&value).unwrap(), DataValue::UInt16(1) ); assert_eq!( - UInt8ToIntegerCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_integer_cast_eval(&value).unwrap(), DataValue::Int32(1) ); assert_eq!( - UInt8ToUIntegerCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_uinteger_cast_eval(&value).unwrap(), DataValue::UInt32(1) ); assert_eq!( - UInt8ToBigintCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_bigint_cast_eval(&value).unwrap(), DataValue::Int64(1) ); assert_eq!( - UInt8ToUBigintCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_ubigint_cast_eval(&value).unwrap(), DataValue::UInt64(1) ); assert_eq!( - UInt8ToFloatCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_float_cast_eval(&value).unwrap(), DataValue::Float32(OrderedFloat(1.0)) ); assert_eq!( - UInt8ToDoubleCastEvaluator.eval_cast(&value).unwrap(), + uint8_to_double_cast_eval(&value).unwrap(), DataValue::Float64(OrderedFloat(1.0)) ); assert_eq!( - UInt8ToCharCastEvaluator { - len: 1, - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + uint8_to_char_cast_eval(1, CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Fixed(1), @@ -90,12 +84,7 @@ mod test { } ); assert_eq!( - UInt8ToVarcharCastEvaluator { - len: Some(1), - unit: CharLengthUnits::Characters, - } - .eval_cast(&value) - .unwrap(), + uint8_to_varchar_cast_eval(Some(1), CharLengthUnits::Characters, &value).unwrap(), DataValue::Utf8 { value: "1".to_string(), ty: Utf8Type::Variable(Some(1)), @@ -103,9 +92,7 @@ mod test { } ); assert_eq!( - UInt8ToDecimalCastEvaluator { scale: Some(2) } - .eval_cast(&value) - .unwrap(), + uint8_to_decimal_cast_eval(Some(2), &value).unwrap(), DataValue::Decimal(Decimal::new(100, 2)) ); } diff --git a/src/types/evaluator/unary.rs b/src/types/evaluator/unary.rs index 1bea7e42..d2f7c805 100644 --- a/src/types/evaluator/unary.rs +++ b/src/types/evaluator/unary.rs @@ -14,79 +14,106 @@ use crate::errors::DatabaseError; use crate::expression::UnaryOperator; -use crate::types::evaluator::boolean::BooleanNotUnaryEvaluator; +use crate::types::evaluator::boolean::boolean_not_unary_eval; use crate::types::evaluator::float32::*; use crate::types::evaluator::float64::*; use crate::types::evaluator::int16::*; use crate::types::evaluator::int32::*; use crate::types::evaluator::int64::*; use crate::types::evaluator::int8::*; -use crate::types::evaluator::UnaryEvaluatorBox; +use crate::types::evaluator::UnaryEvaluatorRef; use crate::types::LogicalType; -use paste::paste; use std::borrow::Cow; -use std::sync::Arc; - -macro_rules! box_unary { - ($ty:expr, $op:expr, $evaluator:expr) => { - Ok(UnaryEvaluatorBox::new( - Arc::new($evaluator), - $ty.clone(), - $op, - )) - }; -} -macro_rules! numeric_unary_evaluator { - ($value_type:ident, $op:expr, $ty:expr) => { - paste! { - match $op { - UnaryOperator::Plus => box_unary!($ty, $op, [<$value_type PlusUnaryEvaluator>]), - UnaryOperator::Minus => box_unary!($ty, $op, [<$value_type MinusUnaryEvaluator>]), - _ => Err(DatabaseError::UnsupportedUnaryOperator($ty.clone(), $op)), - } - } +const UNARY_INT8_PLUS: u16 = 0; +const UNARY_INT8_MINUS: u16 = 1; +const UNARY_INT16_PLUS: u16 = 2; +const UNARY_INT16_MINUS: u16 = 3; +const UNARY_INT32_PLUS: u16 = 4; +const UNARY_INT32_MINUS: u16 = 5; +const UNARY_INT64_PLUS: u16 = 6; +const UNARY_INT64_MINUS: u16 = 7; +const UNARY_BOOLEAN_NOT: u16 = 8; +const UNARY_FLOAT32_PLUS: u16 = 9; +const UNARY_FLOAT32_MINUS: u16 = 10; +const UNARY_FLOAT64_PLUS: u16 = 11; +const UNARY_FLOAT64_MINUS: u16 = 12; + +// Evaluator positions are serialized ABI. Do not reorder or reuse existing +// positions; only append new positions at the end of the current layout. + +fn numeric_unary_ref( + plus: u16, + minus: u16, + ty: &LogicalType, + op: UnaryOperator, +) -> Result { + let pos = match op { + UnaryOperator::Plus => plus, + UnaryOperator::Minus => minus, + _ => return Err(DatabaseError::UnsupportedUnaryOperator(ty.clone(), op)), }; + Ok(UnaryEvaluatorRef::new(pos)) } pub fn unary_create( ty: Cow<'_, LogicalType>, op: UnaryOperator, -) -> Result { +) -> Result { let ty = ty.as_ref(); match ty { - LogicalType::Tinyint => numeric_unary_evaluator!(Int8, op, ty), - LogicalType::Smallint => numeric_unary_evaluator!(Int16, op, ty), - LogicalType::Integer => numeric_unary_evaluator!(Int32, op, ty), - LogicalType::Bigint => numeric_unary_evaluator!(Int64, op, ty), + LogicalType::Tinyint => numeric_unary_ref(UNARY_INT8_PLUS, UNARY_INT8_MINUS, ty, op), + LogicalType::Smallint => numeric_unary_ref(UNARY_INT16_PLUS, UNARY_INT16_MINUS, ty, op), + LogicalType::Integer => numeric_unary_ref(UNARY_INT32_PLUS, UNARY_INT32_MINUS, ty, op), + LogicalType::Bigint => numeric_unary_ref(UNARY_INT64_PLUS, UNARY_INT64_MINUS, ty, op), LogicalType::Boolean => match op { - UnaryOperator::Not => box_unary!(ty, op, BooleanNotUnaryEvaluator), + UnaryOperator::Not => Ok(UnaryEvaluatorRef::new(UNARY_BOOLEAN_NOT)), _ => Err(DatabaseError::UnsupportedUnaryOperator(ty.clone(), op)), }, - LogicalType::Float => numeric_unary_evaluator!(Float32, op, ty), - LogicalType::Double => numeric_unary_evaluator!(Float64, op, ty), + LogicalType::Float => numeric_unary_ref(UNARY_FLOAT32_PLUS, UNARY_FLOAT32_MINUS, ty, op), + LogicalType::Double => numeric_unary_ref(UNARY_FLOAT64_PLUS, UNARY_FLOAT64_MINUS, ty, op), _ => Err(DatabaseError::UnsupportedUnaryOperator(ty.clone(), op)), } } +pub(crate) fn eval_unary( + pos: u16, + value: &crate::types::value::DataValue, +) -> crate::types::value::DataValue { + match pos { + UNARY_INT8_PLUS => int8_plus_unary_eval(value), + UNARY_INT8_MINUS => int8_minus_unary_eval(value), + UNARY_INT16_PLUS => int16_plus_unary_eval(value), + UNARY_INT16_MINUS => int16_minus_unary_eval(value), + UNARY_INT32_PLUS => int32_plus_unary_eval(value), + UNARY_INT32_MINUS => int32_minus_unary_eval(value), + UNARY_INT64_PLUS => int64_plus_unary_eval(value), + UNARY_INT64_MINUS => int64_minus_unary_eval(value), + UNARY_BOOLEAN_NOT => boolean_not_unary_eval(value), + UNARY_FLOAT32_PLUS => float32_plus_unary_eval(value), + UNARY_FLOAT32_MINUS => float32_minus_unary_eval(value), + UNARY_FLOAT64_PLUS => float64_plus_unary_eval(value), + UNARY_FLOAT64_MINUS => float64_minus_unary_eval(value), + _ => unreachable!("unknown unary evaluator position {pos}"), + } +} + #[macro_export] macro_rules! numeric_unary_evaluator_definition { ($value_type:ident, $compute_type:path) => { paste::paste! { - #[derive(Debug)] - pub struct [<$value_type PlusUnaryEvaluator>]; - #[derive(Debug)] - pub struct [<$value_type MinusUnaryEvaluator>]; impl $crate::types::evaluator::UnaryEvaluator for [<$value_type PlusUnaryEvaluator>] { - fn unary_eval(&self, value: &$crate::types::value::DataValue) -> $crate::types::value::DataValue { - value.clone() - } - } impl $crate::types::evaluator::UnaryEvaluator for [<$value_type MinusUnaryEvaluator>] { - fn unary_eval(&self, value: &$crate::types::value::DataValue) -> $crate::types::value::DataValue { - match value { - $compute_type(value) => $compute_type(-value), - $crate::types::value::DataValue::Null => $crate::types::value::DataValue::Null, - _ => unsafe { std::hint::unreachable_unchecked() }, - } + pub fn [<$value_type:snake _plus_unary_eval>]( + value: &$crate::types::value::DataValue, + ) -> $crate::types::value::DataValue { + value.clone() + } + pub fn [<$value_type:snake _minus_unary_eval>]( + value: &$crate::types::value::DataValue, + ) -> $crate::types::value::DataValue { + match value { + $compute_type(value) => $compute_type(-value), + $crate::types::value::DataValue::Null => $crate::types::value::DataValue::Null, + _ => unsafe { std::hint::unreachable_unchecked() }, } } } @@ -95,22 +122,40 @@ macro_rules! numeric_unary_evaluator_definition { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { - use super::unary_create; + use super::*; use crate::errors::DatabaseError; use crate::expression::UnaryOperator; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::rocksdb::RocksTransaction; - use crate::types::evaluator::UnaryEvaluatorBox; + use crate::types::evaluator::UnaryEvaluatorRef; use crate::types::value::DataValue; use crate::types::LogicalType; use ordered_float::OrderedFloat; use std::borrow::Cow; use std::io::{Cursor, Seek, SeekFrom}; - fn create(ty: LogicalType, op: UnaryOperator) -> Result { + fn create(ty: LogicalType, op: UnaryOperator) -> Result { unary_create(Cow::Owned(ty), op) } + #[test] + fn test_unary_evaluator_positions_are_stable() -> Result<(), DatabaseError> { + assert_eq!( + create(LogicalType::Integer, UnaryOperator::Minus)?.pos, + UNARY_INT32_MINUS + ); + assert_eq!( + create(LogicalType::Boolean, UnaryOperator::Not)?.pos, + UNARY_BOOLEAN_NOT + ); + assert_eq!( + create(LogicalType::Double, UnaryOperator::Plus)?.pos, + UNARY_FLOAT64_PLUS + ); + + Ok(()) + } + #[test] fn test_numeric_unary_evaluators() -> Result<(), DatabaseError> { let cases = vec![ @@ -150,7 +195,7 @@ mod test { } #[test] - fn test_boolean_unary_evaluator() -> Result<(), DatabaseError> { + fn test_boolean_unary_eval() -> Result<(), DatabaseError> { let evaluator = create(LogicalType::Boolean, UnaryOperator::Not)?; assert_eq!( evaluator.unary_eval(&DataValue::Boolean(true)), @@ -174,12 +219,18 @@ mod test { let evaluator = create(LogicalType::Boolean, UnaryOperator::Not)?; let mut cursor = Cursor::new(Vec::new()); let mut reference_tables = ReferenceTables::new(); + let mut arena = crate::planner::TableArena::default(); - evaluator.encode(&mut cursor, false, &mut reference_tables)?; + evaluator.encode(&mut cursor, false, &mut reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( - UnaryEvaluatorBox::decode::(&mut cursor, None, &reference_tables)?, + UnaryEvaluatorRef::decode::( + &mut cursor, + None, + &reference_tables, + &mut arena + )?, evaluator ); diff --git a/src/types/evaluator/utf8.rs b/src/types/evaluator/utf8.rs index 09450917..e95ba894 100644 --- a/src/types/evaluator/utf8.rs +++ b/src/types/evaluator/utf8.rs @@ -13,309 +13,249 @@ // limitations under the License. use crate::errors::DatabaseError; -use crate::types::evaluator::cast::{cast_fail, to_char, to_varchar}; -use crate::types::evaluator::BinaryEvaluator; +use crate::types::evaluator::cast::{to_char, to_varchar}; use crate::types::evaluator::DataValue; use crate::types::value::Utf8Type; use crate::types::CharLengthUnits; -use crate::types::LogicalType; -use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; use ordered_float::OrderedFloat; +#[cfg(feature = "decimal")] use rust_decimal::Decimal; use std::hint; use std::str::FromStr; -#[derive(Debug)] -pub struct Utf8GtBinaryEvaluator; -#[derive(Debug)] -pub struct Utf8GtEqBinaryEvaluator; -#[derive(Debug)] -pub struct Utf8LtBinaryEvaluator; -#[derive(Debug)] -pub struct Utf8LtEqBinaryEvaluator; -#[derive(Debug)] -pub struct Utf8EqBinaryEvaluator; -#[derive(Debug)] -pub struct Utf8NotEqBinaryEvaluator; -#[derive(Debug)] -pub struct Utf8StringConcatBinaryEvaluator; -#[derive(Debug)] -pub struct Utf8LikeBinaryEvaluator { - pub(crate) escape_char: Option, -} -#[derive(Debug)] -pub struct Utf8NotLikeBinaryEvaluator { - pub(crate) escape_char: Option, -} -impl BinaryEvaluator for Utf8GtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value: v1, .. }, DataValue::Utf8 { value: v2, .. }) => { - DataValue::Boolean(v1 > v2) - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Utf8GtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value: v1, .. }, DataValue::Utf8 { value: v2, .. }) => { - DataValue::Boolean(v1 >= v2) - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Utf8LtBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value: v1, .. }, DataValue::Utf8 { value: v2, .. }) => { - DataValue::Boolean(v1 < v2) - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +macro_rules! utf8_binary { + ($name:ident, $body:expr) => { + pub fn $name(left: &DataValue, right: &DataValue) -> Result { + Ok(match (left, right) { + (DataValue::Utf8 { value: v1, .. }, DataValue::Utf8 { value: v2, .. }) => { + $body(v1, v2) + } + (DataValue::Utf8 { .. }, DataValue::Null) + | (DataValue::Null, DataValue::Utf8 { .. }) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) + } + }; } -crate::define_cast_evaluator!( - Utf8ToBooleanCastEvaluator { - from: LogicalType - }, - DataValue::Utf8 { value, .. } => { +utf8_binary!(utf8_gt_binary_eval, |v1: &String, v2: &String| { + DataValue::Boolean(v1 > v2) +}); +utf8_binary!(utf8_gt_eq_binary_eval, |v1: &String, v2: &String| { + DataValue::Boolean(v1 >= v2) +}); +utf8_binary!(utf8_lt_binary_eval, |v1: &String, v2: &String| { + DataValue::Boolean(v1 < v2) +}); + +crate::define_cast_evaluator!(utf8_to_boolean_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::Boolean(bool::from_str(value)?)) } ); -crate::define_cast_evaluator!(Utf8ToTinyintCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_tinyint_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::Int8(i8::from_str(value)?)) }); -crate::define_cast_evaluator!(Utf8ToUTinyintCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_utinyint_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::UInt8(u8::from_str(value)?)) }); -crate::define_cast_evaluator!(Utf8ToSmallintCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_smallint_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::Int16(i16::from_str(value)?)) }); -crate::define_cast_evaluator!(Utf8ToUSmallintCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_usmallint_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::UInt16(u16::from_str(value)?)) }); -crate::define_cast_evaluator!(Utf8ToIntegerCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_integer_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::Int32(i32::from_str(value)?)) }); -crate::define_cast_evaluator!(Utf8ToUIntegerCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_uinteger_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::UInt32(u32::from_str(value)?)) }); -crate::define_cast_evaluator!(Utf8ToBigintCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_bigint_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::Int64(i64::from_str(value)?)) }); -crate::define_cast_evaluator!(Utf8ToUBigintCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_ubigint_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::UInt64(u64::from_str(value)?)) }); -crate::define_cast_evaluator!(Utf8ToFloatCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_float_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::Float32(OrderedFloat(f32::from_str(value)?))) }); -crate::define_cast_evaluator!(Utf8ToDoubleCastEvaluator, DataValue::Utf8 { value, .. } => { +crate::define_cast_evaluator!(utf8_to_double_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::Float64(OrderedFloat(f64::from_str(value)?))) }); crate::define_cast_evaluator!( - Utf8ToCharCastEvaluator { + utf8_to_char_cast_eval { len: u32, unit: CharLengthUnits }, DataValue::Utf8 { value, .. } => |this| to_char(value.clone(), this.len, this.unit) ); crate::define_cast_evaluator!( - Utf8ToVarcharCastEvaluator { + utf8_to_varchar_cast_eval { len: Option, unit: CharLengthUnits }, DataValue::Utf8 { value, .. } => |this| to_varchar(value.clone(), this.len, this.unit) ); -crate::define_cast_evaluator!(Utf8ToDateCastEvaluator, DataValue::Utf8 { value, .. } => { - Ok(DataValue::Date32( - NaiveDate::parse_from_str(value, crate::types::value::DATE_FMT)?.num_days_from_ce(), - )) -}); -crate::define_cast_evaluator!(Utf8ToDatetimeCastEvaluator, DataValue::Utf8 { value, .. } => { - let value = NaiveDateTime::parse_from_str(value, crate::types::value::DATE_TIME_FMT) - .or_else(|_| { - NaiveDate::parse_from_str(value, crate::types::value::DATE_FMT) - .map(|date| date.and_hms_opt(0, 0, 0).unwrap()) - })? - .and_utc() - .timestamp(); +#[cfg(feature = "time")] +mod chrono_cast { + use super::DataValue; + use crate::types::evaluator::cast::cast_fail; + use crate::types::LogicalType; + use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; - Ok(DataValue::Date64(value)) -}); -crate::define_cast_evaluator!( - Utf8ToTimeCastEvaluator { - precision: Option - }, - DataValue::Utf8 { value, .. } => |this| { - let precision = this.precision.unwrap_or(0); - let fmt = if precision == 0 { - crate::types::value::TIME_FMT - } else { - crate::types::value::TIME_FMT_WITHOUT_ZONE - }; - let (value, nano) = match precision { - 0 => ( - NaiveTime::parse_from_str(value, fmt) - .map(|time| time.num_seconds_from_midnight())?, - 0, - ), - _ => NaiveTime::parse_from_str(value, fmt) - .map(|time| (time.num_seconds_from_midnight(), time.nanosecond()))?, - }; + crate::define_cast_evaluator!(utf8_to_date_cast_eval, DataValue::Utf8 { value, .. } => { + Ok(DataValue::Date32( + NaiveDate::parse_from_str(value, crate::types::value::DATE_FMT)?.num_days_from_ce(), + )) + }); - Ok(DataValue::Time32(DataValue::pack(value, nano, precision), precision)) - } -); -crate::define_cast_evaluator!( - Utf8ToTimestampCastEvaluator { - precision: Option, - zone: bool, - to: LogicalType - }, - DataValue::Utf8 { value, .. } => |this| { - let precision = this.precision.unwrap_or(0); - let fmt = match (precision, this.zone) { - (0, false) => crate::types::value::DATE_TIME_FMT, - (0, true) => crate::types::value::TIME_STAMP_FMT_WITHOUT_PRECISION, - (3 | 6 | 9, false) => crate::types::value::TIME_STAMP_FMT_WITHOUT_ZONE, - _ => crate::types::value::TIME_STAMP_FMT_WITH_ZONE, - }; - let complete_value = if this.zone { - if value.contains("+") { - value.clone() + crate::define_cast_evaluator!(utf8_to_datetime_cast_eval, DataValue::Utf8 { value, .. } => { + let value = NaiveDateTime::parse_from_str(value, crate::types::value::DATE_TIME_FMT) + .or_else(|_| { + NaiveDate::parse_from_str(value, crate::types::value::DATE_FMT) + .map(|date| date.and_hms_opt(0, 0, 0).unwrap()) + })? + .and_utc() + .timestamp(); + + Ok(DataValue::Date64(value)) + }); + + crate::define_cast_evaluator!( + utf8_to_time_cast_eval { + precision: Option + }, + DataValue::Utf8 { value, .. } => |this| { + let precision = this.precision.unwrap_or(0); + let fmt = if precision == 0 { + crate::types::value::TIME_FMT } else { - format!("{value}+00:00") + crate::types::value::TIME_FMT_WITHOUT_ZONE + }; + let (value, nano) = match precision { + 0 => ( + NaiveTime::parse_from_str(value, fmt) + .map(|time| time.num_seconds_from_midnight())?, + 0, + ), + _ => NaiveTime::parse_from_str(value, fmt) + .map(|time| (time.num_seconds_from_midnight(), time.nanosecond()))?, + }; + + Ok(DataValue::Time32(DataValue::pack(value, nano, precision), precision)) + } + ); + + crate::define_cast_evaluator!( + utf8_to_timestamp_cast_eval { + precision: Option, + zone: bool + }, + DataValue::Utf8 { value, .. } => |this| { + let precision = this.precision.unwrap_or(0); + let target_type = || LogicalType::TimeStamp(this.precision, this.zone); + let fmt = match (precision, this.zone) { + (0, false) => crate::types::value::DATE_TIME_FMT, + (0, true) => crate::types::value::TIME_STAMP_FMT_WITHOUT_PRECISION, + (3 | 6 | 9, false) => crate::types::value::TIME_STAMP_FMT_WITHOUT_ZONE, + _ => crate::types::value::TIME_STAMP_FMT_WITH_ZONE, + }; + let complete_value = if this.zone { + if value.contains("+") { + value.clone() + } else { + format!("{value}+00:00") + } + } else { + value.clone() + }; + + if !this.zone { + let value = NaiveDateTime::parse_from_str(&complete_value, fmt)?.and_utc(); + let value = match precision { + 3 => value.timestamp_millis(), + 6 => value.timestamp_micros(), + 9 => value + .timestamp_nanos_opt() + .ok_or_else(|| cast_fail(target_type(), target_type()))?, + 0 => value.timestamp(), + _ => unreachable!(), + }; + + return Ok(DataValue::Time64(value, precision, false)); } - } else { - value.clone() - }; - if !this.zone { - let value = NaiveDateTime::parse_from_str(&complete_value, fmt)?.and_utc(); + let value = DateTime::parse_from_str(&complete_value, fmt); let value = match precision { - 3 => value.timestamp_millis(), - 6 => value.timestamp_micros(), + 3 => value.map(|date_time| date_time.timestamp_millis())?, + 6 => value.map(|date_time| date_time.timestamp_micros())?, 9 => value - .timestamp_nanos_opt() - .ok_or_else(|| cast_fail(this.to.clone(), this.to.clone()))?, - 0 => value.timestamp(), + .map(|date_time| date_time.timestamp_nanos_opt())? + .ok_or_else(|| cast_fail(target_type(), target_type()))?, + 0 => value.map(|date_time| date_time.timestamp())?, _ => unreachable!(), }; - return Ok(DataValue::Time64(value, precision, false)); + Ok(DataValue::Time64(value, precision, this.zone)) } + ); +} - let value = DateTime::parse_from_str(&complete_value, fmt); - let value = match precision { - 3 => value.map(|date_time| date_time.timestamp_millis())?, - 6 => value.map(|date_time| date_time.timestamp_micros())?, - 9 => value - .map(|date_time| date_time.timestamp_nanos_opt())? - .ok_or_else(|| cast_fail(this.to.clone(), this.to.clone()))?, - 0 => value.map(|date_time| date_time.timestamp())?, - _ => unreachable!(), - }; - - Ok(DataValue::Time64(value, precision, this.zone)) - } -); -crate::define_cast_evaluator!(Utf8ToDecimalCastEvaluator, DataValue::Utf8 { value, .. } => { +#[cfg(feature = "time")] +pub use chrono_cast::*; +#[cfg(feature = "decimal")] +crate::define_cast_evaluator!(utf8_to_decimal_cast_eval, DataValue::Utf8 { value, .. } => { Ok(DataValue::Decimal(Decimal::from_str(value)?)) }); -impl BinaryEvaluator for Utf8LtEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value: v1, .. }, DataValue::Utf8 { value: v2, .. }) => { - DataValue::Boolean(v1 <= v2) - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Utf8EqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value: v1, .. }, DataValue::Utf8 { value: v2, .. }) => { - DataValue::Boolean(v1 == v2) - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Utf8NotEqBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value: v1, .. }, DataValue::Utf8 { value: v2, .. }) => { - DataValue::Boolean(v1 != v2) - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Utf8StringConcatBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value: v1, .. }, DataValue::Utf8 { value: v2, .. }) => { - DataValue::Utf8 { - value: v1.clone() + v2, - ty: Utf8Type::Variable(None), - unit: CharLengthUnits::Characters, - } - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } -} -impl BinaryEvaluator for Utf8LikeBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value, .. }, DataValue::Utf8 { value: pattern, .. }) => { - DataValue::Boolean(string_like(value, pattern, self.escape_char)) - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) +utf8_binary!(utf8_lt_eq_binary_eval, |v1: &String, v2: &String| { + DataValue::Boolean(v1 <= v2) +}); +utf8_binary!(utf8_eq_binary_eval, |v1: &String, v2: &String| { + DataValue::Boolean(v1 == v2) +}); +utf8_binary!(utf8_not_eq_binary_eval, |v1: &String, v2: &String| { + DataValue::Boolean(v1 != v2) +}); +utf8_binary!( + utf8_string_concat_binary_eval, + |v1: &String, v2: &String| { + DataValue::Utf8 { + value: v1.clone() + v2, + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + } } +); +pub fn utf8_like_binary_eval( + escape_char: Option, + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Utf8 { value, .. }, DataValue::Utf8 { value: pattern, .. }) => { + DataValue::Boolean(string_like(value, pattern, escape_char)) + } + (DataValue::Utf8 { .. }, DataValue::Null) + | (DataValue::Null, DataValue::Utf8 { .. }) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } -impl BinaryEvaluator for Utf8NotLikeBinaryEvaluator { - fn binary_eval(&self, left: &DataValue, right: &DataValue) -> Result { - Ok(match (left, right) { - (DataValue::Utf8 { value, .. }, DataValue::Utf8 { value: pattern, .. }) => { - DataValue::Boolean(!string_like(value, pattern, self.escape_char)) - } - (DataValue::Utf8 { .. }, DataValue::Null) - | (DataValue::Null, DataValue::Utf8 { .. }) - | (DataValue::Null, DataValue::Null) => DataValue::Null, - _ => unsafe { hint::unreachable_unchecked() }, - }) - } +pub fn utf8_not_like_binary_eval( + escape_char: Option, + left: &DataValue, + right: &DataValue, +) -> Result { + Ok(match (left, right) { + (DataValue::Utf8 { value, .. }, DataValue::Utf8 { value: pattern, .. }) => { + DataValue::Boolean(!string_like(value, pattern, escape_char)) + } + (DataValue::Utf8 { .. }, DataValue::Null) + | (DataValue::Null, DataValue::Utf8 { .. }) + | (DataValue::Null, DataValue::Null) => DataValue::Null, + _ => unsafe { hint::unreachable_unchecked() }, + }) } fn string_like(value: &str, pattern: &str, escape_char: Option) -> bool { @@ -407,7 +347,7 @@ fn next_char_at(input: &str, index: usize) -> Option<(char, usize)> { #[cfg(all(test, not(target_arch = "wasm32")))] mod test { use super::*; - use crate::types::evaluator::{BinaryEvaluator, CastEvaluator}; + use chrono::{Datelike, Timelike}; fn utf8(value: &str) -> DataValue { DataValue::Utf8 { @@ -420,27 +360,19 @@ mod test { #[test] fn test_utf8_binary_evaluators() { assert_eq!( - Utf8LtBinaryEvaluator - .binary_eval(&utf8("a"), &utf8("b")) - .unwrap(), + utf8_lt_binary_eval(&utf8("a"), &utf8("b")).unwrap(), DataValue::Boolean(true) ); assert_eq!( - Utf8StringConcatBinaryEvaluator - .binary_eval(&utf8("ab"), &utf8("cd")) - .unwrap(), + utf8_string_concat_binary_eval(&utf8("ab"), &utf8("cd")).unwrap(), utf8("abcd") ); assert_eq!( - Utf8LikeBinaryEvaluator { escape_char: None } - .binary_eval(&utf8("kite"), &utf8("ki%")) - .unwrap(), + utf8_like_binary_eval(None, &utf8("kite"), &utf8("ki%")).unwrap(), DataValue::Boolean(true) ); assert_eq!( - Utf8NotLikeBinaryEvaluator { escape_char: None } - .binary_eval(&utf8("kite"), &utf8("ki%")) - .unwrap(), + utf8_not_like_binary_eval(None, &utf8("kite"), &utf8("ki%")).unwrap(), DataValue::Boolean(false) ); } @@ -496,60 +428,51 @@ mod test { #[test] fn test_utf8_cast_evaluators() { assert_eq!( - Utf8ToBooleanCastEvaluator { - from: LogicalType::Varchar(None, CharLengthUnits::Characters), - } - .eval_cast(&utf8("true")) - .unwrap(), + utf8_to_boolean_cast_eval(&utf8("true")).unwrap(), DataValue::Boolean(true) ); assert_eq!( - Utf8ToTinyintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + utf8_to_tinyint_cast_eval(&utf8("1")).unwrap(), DataValue::Int8(1) ); assert_eq!( - Utf8ToUTinyintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + utf8_to_utinyint_cast_eval(&utf8("1")).unwrap(), DataValue::UInt8(1) ); assert_eq!( - Utf8ToSmallintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + utf8_to_smallint_cast_eval(&utf8("1")).unwrap(), DataValue::Int16(1) ); assert_eq!( - Utf8ToUSmallintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + utf8_to_usmallint_cast_eval(&utf8("1")).unwrap(), DataValue::UInt16(1) ); assert_eq!( - Utf8ToIntegerCastEvaluator.eval_cast(&utf8("1")).unwrap(), + utf8_to_integer_cast_eval(&utf8("1")).unwrap(), DataValue::Int32(1) ); assert_eq!( - Utf8ToUIntegerCastEvaluator.eval_cast(&utf8("1")).unwrap(), + utf8_to_uinteger_cast_eval(&utf8("1")).unwrap(), DataValue::UInt32(1) ); assert_eq!( - Utf8ToBigintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + utf8_to_bigint_cast_eval(&utf8("1")).unwrap(), DataValue::Int64(1) ); assert_eq!( - Utf8ToUBigintCastEvaluator.eval_cast(&utf8("1")).unwrap(), + utf8_to_ubigint_cast_eval(&utf8("1")).unwrap(), DataValue::UInt64(1) ); assert_eq!( - Utf8ToFloatCastEvaluator.eval_cast(&utf8("1.5")).unwrap(), + utf8_to_float_cast_eval(&utf8("1.5")).unwrap(), DataValue::Float32(OrderedFloat(1.5)) ); assert_eq!( - Utf8ToDoubleCastEvaluator.eval_cast(&utf8("1.5")).unwrap(), + utf8_to_double_cast_eval(&utf8("1.5")).unwrap(), DataValue::Float64(OrderedFloat(1.5)) ); assert_eq!( - Utf8ToCharCastEvaluator { - len: 2, - unit: CharLengthUnits::Characters, - } - .eval_cast(&utf8("ab")) - .unwrap(), + utf8_to_char_cast_eval(2, CharLengthUnits::Characters, &utf8("ab")).unwrap(), DataValue::Utf8 { value: "ab".to_string(), ty: Utf8Type::Fixed(2), @@ -557,12 +480,7 @@ mod test { } ); assert_eq!( - Utf8ToVarcharCastEvaluator { - len: Some(2), - unit: CharLengthUnits::Characters, - } - .eval_cast(&utf8("ab")) - .unwrap(), + utf8_to_varchar_cast_eval(Some(2), CharLengthUnits::Characters, &utf8("ab")).unwrap(), DataValue::Utf8 { value: "ab".to_string(), ty: Utf8Type::Variable(Some(2)), @@ -570,9 +488,7 @@ mod test { } ); assert_eq!( - Utf8ToDateCastEvaluator - .eval_cast(&utf8("2024-01-02")) - .unwrap(), + utf8_to_date_cast_eval(&utf8("2024-01-02")).unwrap(), DataValue::Date32( chrono::NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() @@ -580,9 +496,7 @@ mod test { ) ); assert_eq!( - Utf8ToDatetimeCastEvaluator - .eval_cast(&utf8("2024-01-02 03:04:05")) - .unwrap(), + utf8_to_datetime_cast_eval(&utf8("2024-01-02 03:04:05")).unwrap(), DataValue::Date64( chrono::NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() @@ -593,15 +507,11 @@ mod test { ) ); assert_eq!( - Utf8ToTimeCastEvaluator { precision: Some(0) } - .eval_cast(&utf8("03:04:05")) - .unwrap(), + utf8_to_time_cast_eval(Some(0), &utf8("03:04:05")).unwrap(), DataValue::Time32(DataValue::pack(3 * 3600 + 4 * 60 + 5, 0, 0), 0) ); assert_eq!( - Utf8ToTimeCastEvaluator { precision: Some(3) } - .eval_cast(&utf8("03:04:05.123")) - .unwrap(), + utf8_to_time_cast_eval(Some(3), &utf8("03:04:05.123")).unwrap(), { let time = chrono::NaiveTime::parse_from_str( "03:04:05.123", @@ -615,13 +525,7 @@ mod test { } ); assert_eq!( - Utf8ToTimestampCastEvaluator { - precision: Some(3), - zone: false, - to: LogicalType::TimeStamp(Some(3), false), - } - .eval_cast(&utf8("2024-01-02 03:04:05.123")) - .unwrap(), + utf8_to_timestamp_cast_eval(Some(3), false, &utf8("2024-01-02 03:04:05.123")).unwrap(), DataValue::Time64( chrono::NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() @@ -634,13 +538,7 @@ mod test { ) ); assert_eq!( - Utf8ToTimestampCastEvaluator { - precision: Some(0), - zone: true, - to: LogicalType::TimeStamp(Some(0), true), - } - .eval_cast(&utf8("2024-01-02 03:04:05+00:00")) - .unwrap(), + utf8_to_timestamp_cast_eval(Some(0), true, &utf8("2024-01-02 03:04:05+00:00")).unwrap(), DataValue::Time64( chrono::NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() @@ -653,9 +551,7 @@ mod test { ) ); assert_eq!( - Utf8ToDecimalCastEvaluator - .eval_cast(&utf8("12.34")) - .unwrap(), + utf8_to_decimal_cast_eval(&utf8("12.34")).unwrap(), DataValue::Decimal(Decimal::from_str("12.34").unwrap()) ); } diff --git a/src/types/index.rs b/src/types/index.rs index 1da63242..2404c285 100644 --- a/src/types/index.rs +++ b/src/types/index.rs @@ -17,6 +17,7 @@ use crate::errors::DatabaseError; use crate::expression::range_detacher::Range; use crate::expression::ScalarExpression; use crate::planner::operator::SortOption; +use crate::planner::PlanArena; use crate::types::serialize::TupleValueSerializableImpl; use crate::types::value::DataValue; use crate::types::{ColumnId, LogicalType}; @@ -24,13 +25,26 @@ use kite_sql_serde_macros::ReferenceSerialization; use std::collections::Bound; use std::fmt; use std::fmt::Formatter; -use std::sync::Arc; pub type IndexId = u32; -pub type IndexMetaRef = Arc; pub const INDEX_ID_LEN: usize = 4; +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +pub struct IndexMetaRef { + pos: usize, +} + +impl IndexMetaRef { + pub(crate) fn new(pos: usize) -> Self { + Self { pos } + } + + pub(crate) fn pos(self) -> usize { + self.pos + } +} + #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, ReferenceSerialization)] pub enum IndexType { PrimaryKey { is_multiple: bool }, @@ -80,16 +94,18 @@ impl IndexMeta { pub(crate) fn column_exprs( &self, table: &TableCatalog, + arena: &PlanArena, ) -> Result, DatabaseError> { let mut exprs = Vec::with_capacity(self.column_ids.len()); for column_id in self.column_ids.iter() { - if let Some((position, column)) = table + if let Some((position, column_ref)) = table .columns() + .copied() .enumerate() - .find(|(_, column)| column.id() == Some(*column_id)) + .find(|(_, column)| arena.column(*column).id() == Some(*column_id)) { - exprs.push(ScalarExpression::column_expr(column.clone(), position)); + exprs.push(ScalarExpression::column_expr(column_ref, position)); } else { return Err(DatabaseError::column_not_found(column_id.to_string())); } @@ -137,3 +153,9 @@ impl fmt::Display for IndexMeta { write!(f, "{}", self.name) } } + +impl fmt::Display for IndexMetaRef { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "#{}", self.pos) + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index bb88ad78..ee4785fc 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -19,35 +19,26 @@ pub mod tuple; pub mod tuple_builder; pub mod value; +#[cfg(feature = "time")] use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +#[cfg(feature = "decimal")] use rust_decimal::Decimal; -use serde::{Deserialize, Serialize}; use std::any::TypeId; use std::borrow::Cow; use std::cmp; use crate::errors::DatabaseError; use kite_sql_serde_macros::ReferenceSerialization; -use sqlparser::ast::{ExactNumberInfo, TimezoneInfo}; use ulid::Ulid; pub type ColumnId = Ulid; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum CharLengthUnits { Characters, Octets, } -impl From for CharLengthUnits { - fn from(value: sqlparser::ast::CharLengthUnits) -> Self { - match value { - sqlparser::ast::CharLengthUnits::Characters => Self::Characters, - sqlparser::ast::CharLengthUnits::Octets => Self::Octets, - } - } -} - impl std::fmt::Display for CharLengthUnits { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -57,20 +48,7 @@ impl std::fmt::Display for CharLengthUnits { } } -/// Sqlrs type conversion: -/// sqlparser::ast::DataType -> LogicalType -> arrow::datatypes::DataType -#[derive( - Debug, - Clone, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - Serialize, - Deserialize, - ReferenceSerialization, -)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, ReferenceSerialization)] pub enum LogicalType { SqlNull, Boolean, @@ -99,6 +77,11 @@ impl LogicalType { pub fn type_trans() -> Option { let type_id = TypeId::of::(); + #[cfg(feature = "decimal")] + if type_id == TypeId::of::() { + return Some(LogicalType::Decimal(None, None)); + } + if type_id == TypeId::of::() { Some(LogicalType::Boolean) } else if type_id == TypeId::of::() { @@ -121,17 +104,21 @@ impl LogicalType { Some(LogicalType::Float) } else if type_id == TypeId::of::() { Some(LogicalType::Double) - } else if type_id == TypeId::of::() { - Some(LogicalType::Date) - } else if type_id == TypeId::of::() { - Some(LogicalType::DateTime) - } else if type_id == TypeId::of::() { - Some(LogicalType::Time(Some(0))) - } else if type_id == TypeId::of::() { - Some(LogicalType::Decimal(None, None)) } else if type_id == TypeId::of::() { Some(LogicalType::Varchar(None, CharLengthUnits::Characters)) } else { + #[cfg(feature = "time")] + { + if type_id == TypeId::of::() { + return Some(LogicalType::Date); + } + if type_id == TypeId::of::() { + return Some(LogicalType::DateTime); + } + if type_id == TypeId::of::() { + return Some(LogicalType::Time(Some(0))); + } + } None } } @@ -244,25 +231,29 @@ impl LogicalType { if left.is_numeric() && right.is_numeric() { return LogicalType::combine_numeric_types(left, right); } - if matches!( - (left, right), - (LogicalType::Date, LogicalType::Varchar(..)) - | (LogicalType::Varchar(..), LogicalType::Date) - ) { - return Ok(Cow::Owned(LogicalType::Date)); - } - if matches!( - (left, right), - (LogicalType::Date, LogicalType::DateTime) | (LogicalType::DateTime, LogicalType::Date) - ) { - return Ok(Cow::Owned(LogicalType::DateTime)); - } - if matches!( - (left, right), - (LogicalType::DateTime, LogicalType::Varchar(..)) - | (LogicalType::Varchar(..), LogicalType::DateTime) - ) { - return Ok(Cow::Owned(LogicalType::DateTime)); + #[cfg(feature = "time")] + { + if matches!( + (left, right), + (LogicalType::Date, LogicalType::Varchar(..)) + | (LogicalType::Varchar(..), LogicalType::Date) + ) { + return Ok(Cow::Owned(LogicalType::Date)); + } + if matches!( + (left, right), + (LogicalType::Date, LogicalType::DateTime) + | (LogicalType::DateTime, LogicalType::Date) + ) { + return Ok(Cow::Owned(LogicalType::DateTime)); + } + if matches!( + (left, right), + (LogicalType::DateTime, LogicalType::Varchar(..)) + | (LogicalType::Varchar(..), LogicalType::DateTime) + ) { + return Ok(Cow::Owned(LogicalType::DateTime)); + } } if let (LogicalType::Char(..), LogicalType::Varchar(..)) | (LogicalType::Varchar(..), LogicalType::Char(..)) @@ -424,158 +415,6 @@ impl LogicalType { } } -/// sqlparser datatype to logical type -impl TryFrom for LogicalType { - type Error = DatabaseError; - - fn try_from(value: sqlparser::ast::DataType) -> Result { - match value { - sqlparser::ast::DataType::Char(char_len) - | sqlparser::ast::DataType::Character(char_len) => { - let mut len = 1; - let mut char_unit = None; - if let Some(char_len) = char_len { - match char_len { - sqlparser::ast::CharacterLength::IntegerLength { length, unit } => { - len = cmp::max(len, length); - char_unit = unit; - } - sqlparser::ast::CharacterLength::Max => { - return Err(DatabaseError::UnsupportedStmt( - "CHAR(MAX) is not supported".to_string(), - )); - } - } - } - Ok(LogicalType::Char( - len as u32, - char_unit - .map(Into::into) - .unwrap_or(CharLengthUnits::Characters), - )) - } - sqlparser::ast::DataType::CharVarying(varchar_len) - | sqlparser::ast::DataType::CharacterVarying(varchar_len) - | sqlparser::ast::DataType::Varchar(varchar_len) => { - let mut len = None; - let mut char_unit = None; - if let Some(varchar_len) = varchar_len { - match varchar_len { - sqlparser::ast::CharacterLength::IntegerLength { length, unit } => { - len = Some(length as u32); - char_unit = unit; - } - sqlparser::ast::CharacterLength::Max => { - return Err(DatabaseError::UnsupportedStmt( - "VARCHAR(MAX) is not supported".to_string(), - )); - } - } - } - Ok(LogicalType::Varchar( - len, - char_unit - .map(Into::into) - .unwrap_or(CharLengthUnits::Characters), - )) - } - sqlparser::ast::DataType::String(_) | sqlparser::ast::DataType::Text => { - Ok(LogicalType::Varchar(None, CharLengthUnits::Characters)) - } - sqlparser::ast::DataType::Float(_) - | sqlparser::ast::DataType::Float4 - | sqlparser::ast::DataType::Float32 - | sqlparser::ast::DataType::Real => Ok(LogicalType::Float), - sqlparser::ast::DataType::Double(_) - | sqlparser::ast::DataType::DoublePrecision - | sqlparser::ast::DataType::Float8 - | sqlparser::ast::DataType::Float64 => Ok(LogicalType::Double), - sqlparser::ast::DataType::TinyInt(_) => Ok(LogicalType::Tinyint), - sqlparser::ast::DataType::TinyIntUnsigned(_) | sqlparser::ast::DataType::UTinyInt => { - Ok(LogicalType::UTinyint) - } - sqlparser::ast::DataType::SmallInt(_) | sqlparser::ast::DataType::Int2(_) => { - Ok(LogicalType::Smallint) - } - sqlparser::ast::DataType::SmallIntUnsigned(_) - | sqlparser::ast::DataType::Int2Unsigned(_) - | sqlparser::ast::DataType::USmallInt => Ok(LogicalType::USmallint), - sqlparser::ast::DataType::Int(_) - | sqlparser::ast::DataType::Integer(_) - | sqlparser::ast::DataType::Int4(_) - | sqlparser::ast::DataType::Int32 => Ok(LogicalType::Integer), - sqlparser::ast::DataType::IntUnsigned(_) - | sqlparser::ast::DataType::IntegerUnsigned(_) - | sqlparser::ast::DataType::Int4Unsigned(_) - | sqlparser::ast::DataType::Unsigned - | sqlparser::ast::DataType::UnsignedInteger - | sqlparser::ast::DataType::UInt32 => Ok(LogicalType::UInteger), - sqlparser::ast::DataType::BigInt(_) - | sqlparser::ast::DataType::Int8(_) - | sqlparser::ast::DataType::Int64 => Ok(LogicalType::Bigint), - sqlparser::ast::DataType::BigIntUnsigned(_) - | sqlparser::ast::DataType::Int8Unsigned(_) - | sqlparser::ast::DataType::UBigInt - | sqlparser::ast::DataType::UInt64 => Ok(LogicalType::UBigint), - sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), - sqlparser::ast::DataType::Date => Ok(LogicalType::Date), - sqlparser::ast::DataType::Datetime(precision) => { - if precision.is_some() { - return Err(DatabaseError::UnsupportedStmt( - "time's precision".to_string(), - )); - } - Ok(LogicalType::DateTime) - } - sqlparser::ast::DataType::Time(precision, info) => { - match precision { - Some(0..5) | None => (), - _ => { - return Err(DatabaseError::UnsupportedStmt( - "time's precision must be less than 5".to_string(), - )) - } - } - if !matches!(info, TimezoneInfo::None) { - return Err(DatabaseError::UnsupportedStmt( - "time's zone is not supported".to_string(), - )); - } - Ok(LogicalType::Time(precision)) - } - sqlparser::ast::DataType::Timestamp(precision, info) => { - let mut zone = false; - match precision { - Some(3 | 6 | 9) | None => (), - _ => { - return Err(DatabaseError::UnsupportedStmt( - "timestamp's precision must be 3,6,9".to_string(), - )) - } - } - if matches!(info, TimezoneInfo::WithTimeZone) { - zone = true; - } - Ok(LogicalType::TimeStamp(precision, zone)) - } - sqlparser::ast::DataType::Decimal(info) - | sqlparser::ast::DataType::DecimalUnsigned(info) - | sqlparser::ast::DataType::Dec(info) - | sqlparser::ast::DataType::DecUnsigned(info) - | sqlparser::ast::DataType::Numeric(info) => match info { - ExactNumberInfo::None => Ok(Self::Decimal(None, None)), - ExactNumberInfo::Precision(p) => Ok(Self::Decimal(Some(p as u8), None)), - ExactNumberInfo::PrecisionAndScale(p, s) => { - Ok(Self::Decimal(Some(p as u8), Some(s as u8))) - } - }, - other => Err(DatabaseError::UnsupportedStmt(format!( - "unsupported data type: {other}" - ))), - } - } -} - impl std::fmt::Display for LogicalType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -636,11 +475,17 @@ pub(crate) mod test { reference_tables: &mut ReferenceTables, logical_type: LogicalType, ) -> Result<(), DatabaseError> { - logical_type.encode(cursor, false, reference_tables)?; + let mut arena = crate::planner::TableArena::default(); + logical_type.encode(cursor, false, reference_tables, &arena)?; cursor.seek(SeekFrom::Start(0))?; assert_eq!( - LogicalType::decode::(cursor, None, reference_tables)?, + LogicalType::decode::( + cursor, + None, + reference_tables, + &mut arena, + )?, logical_type ); cursor.seek(SeekFrom::Start(0))?; diff --git a/src/types/serialize.rs b/src/types/serialize.rs index dff46585..22bd405a 100644 --- a/src/types/serialize.rs +++ b/src/types/serialize.rs @@ -16,10 +16,10 @@ use crate::errors::DatabaseError; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; use crate::types::LogicalType; -use bumpalo::collections::Vec; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use kite_sql_serde_macros::ReferenceSerialization; use ordered_float::OrderedFloat; +#[cfg(feature = "decimal")] use rust_decimal::Decimal; use std::fmt::Debug; use std::io::{Cursor, Read, Seek, SeekFrom, Write}; @@ -27,11 +27,15 @@ use std::io::{Cursor, Read, Seek, SeekFrom, Write}; macro_rules! impl_tuple_value_serializable { ($name:ident, $variant:ident, $write_fn:expr, $read_fn:expr) => { impl TupleValueSerializable for $name { - fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + fn to_raw( + &self, + value: &DataValue, + writer: &mut W, + ) -> Result<(), DatabaseError> { let DataValue::$variant(v) = value else { unsafe { std::hint::unreachable_unchecked() } }; - ($write_fn)(writer, v)?; + ($write_fn)(&mut *writer, v)?; Ok(()) } @@ -43,7 +47,7 @@ macro_rules! impl_tuple_value_serializable { } pub trait TupleValueSerializable: Debug { - fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError>; + fn to_raw(&self, value: &DataValue, writer: &mut W) -> Result<(), DatabaseError>; #[allow(clippy::wrong_self_convention)] fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result; fn filling_value( @@ -92,7 +96,7 @@ pub enum TupleValueSerializableImpl { } impl TupleValueSerializable for TupleValueSerializableImpl { - fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + fn to_raw(&self, value: &DataValue, writer: &mut W) -> Result<(), DatabaseError> { match self { TupleValueSerializableImpl::Boolean => BooleanSerializable.to_raw(value, writer), TupleValueSerializableImpl::Int8 => Int8Serializable.to_raw(value, writer), @@ -289,29 +293,39 @@ struct SkipFixed(usize); #[derive(Debug)] struct SkipVariable; +fn write_spaces(writer: &mut W, mut len: usize) -> Result<(), DatabaseError> { + const SPACES: [u8; 64] = [b' '; 64]; + while len >= SPACES.len() { + writer.write_all(&SPACES)?; + len -= SPACES.len(); + } + writer.write_all(&SPACES[..len])?; + Ok(()) +} + // Int impl_tuple_value_serializable!( Int8Serializable, Int8, - |writer: &mut Vec, &value| writer.write_i8(value), + |writer: &mut dyn Write, &value| writer.write_i8(value), |reader: &mut Cursor<&[u8]>| reader.read_i8() ); impl_tuple_value_serializable!( Int16Serializable, Int16, - |writer: &mut Vec, &value| writer.write_i16::(value), + |writer: &mut dyn Write, &value| writer.write_i16::(value), |reader: &mut Cursor<&[u8]>| reader.read_i16::() ); impl_tuple_value_serializable!( Int32Serializable, Int32, - |writer: &mut Vec, &value| writer.write_i32::(value), + |writer: &mut dyn Write, &value| writer.write_i32::(value), |reader: &mut Cursor<&[u8]>| reader.read_i32::() ); impl_tuple_value_serializable!( Int64Serializable, Int64, - |writer: &mut Vec, &value| writer.write_i64::(value), + |writer: &mut dyn Write, &value| writer.write_i64::(value), |reader: &mut Cursor<&[u8]>| reader.read_i64::() ); @@ -319,25 +333,25 @@ impl_tuple_value_serializable!( impl_tuple_value_serializable!( UInt8Serializable, UInt8, - |writer: &mut Vec, &value| writer.write_u8(value), + |writer: &mut dyn Write, &value| writer.write_u8(value), |reader: &mut Cursor<&[u8]>| reader.read_u8() ); impl_tuple_value_serializable!( UInt16Serializable, UInt16, - |writer: &mut Vec, &value| writer.write_u16::(value), + |writer: &mut dyn Write, &value| writer.write_u16::(value), |reader: &mut Cursor<&[u8]>| reader.read_u16::() ); impl_tuple_value_serializable!( UInt32Serializable, UInt32, - |writer: &mut Vec, &value| writer.write_u32::(value), + |writer: &mut dyn Write, &value| writer.write_u32::(value), |reader: &mut Cursor<&[u8]>| reader.read_u32::() ); impl_tuple_value_serializable!( UInt64Serializable, UInt64, - |writer: &mut Vec, &value| writer.write_u64::(value), + |writer: &mut dyn Write, &value| writer.write_u64::(value), |reader: &mut Cursor<&[u8]>| reader.read_u64::() ); @@ -345,14 +359,14 @@ impl_tuple_value_serializable!( impl_tuple_value_serializable!( Float32Serializable, Float32, - |writer: &mut Vec, value: &OrderedFloat::| writer + |writer: &mut dyn Write, value: &OrderedFloat::| writer .write_f32::(value.into_inner()), |reader: &mut Cursor<&[u8]>| reader.read_f32::().map(OrderedFloat::) ); impl_tuple_value_serializable!( Float64Serializable, Float64, - |writer: &mut Vec, value: &OrderedFloat::| writer + |writer: &mut dyn Write, value: &OrderedFloat::| writer .write_f64::(value.into_inner()), |reader: &mut Cursor<&[u8]>| reader.read_f64::().map(OrderedFloat::) ); @@ -360,12 +374,12 @@ impl_tuple_value_serializable!( impl_tuple_value_serializable!( BooleanSerializable, Boolean, - |writer: &mut Vec, &value| writer.write_u8(value as u8), + |writer: &mut dyn Write, &value| writer.write_u8(value as u8), |reader: &mut Cursor<&[u8]>| reader.read_u8().map(|v| v != 0) ); impl TupleValueSerializable for CharSerializable { - fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + fn to_raw(&self, value: &DataValue, writer: &mut W) -> Result<(), DatabaseError> { let DataValue::Utf8 { value, unit, @@ -377,11 +391,12 @@ impl TupleValueSerializable for CharSerializable { match unit { CharLengthUnits::Characters => { let chars_len = *len as usize; - let v = format!("{value:chars_len$}"); - let bytes = v.as_bytes(); + let bytes = value.as_bytes(); + let spaces_len = chars_len.saturating_sub(value.chars().count()); - writer.write_u32::(bytes.len() as u32)?; + writer.write_u32::((bytes.len() + spaces_len) as u32)?; writer.write_all(bytes)?; + write_spaces(writer, spaces_len)?; } CharLengthUnits::Octets => { let octets_len = *len as usize; @@ -389,9 +404,7 @@ impl TupleValueSerializable for CharSerializable { debug_assert!(octets_len >= bytes.len()); writer.write_all(bytes)?; - for _ in 0..octets_len - bytes.len() { - writer.write_u8(b' ')?; - } + write_spaces(writer, octets_len - bytes.len())?; } } Ok(()) @@ -420,7 +433,7 @@ impl TupleValueSerializable for CharSerializable { } impl TupleValueSerializable for VarcharSerializable { - fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + fn to_raw(&self, value: &DataValue, writer: &mut W) -> Result<(), DatabaseError> { let DataValue::Utf8 { value, ty: Utf8Type::Variable(_), @@ -452,18 +465,18 @@ impl TupleValueSerializable for VarcharSerializable { impl_tuple_value_serializable!( DateSerializable, Date32, - |writer: &mut Vec, &value| writer.write_i32::(value), + |writer: &mut dyn Write, &value| writer.write_i32::(value), |reader: &mut Cursor<&[u8]>| reader.read_i32::() ); impl_tuple_value_serializable!( DateTimeSerializable, Date64, - |writer: &mut Vec, &value| writer.write_i64::(value), + |writer: &mut dyn Write, &value| writer.write_i64::(value), |reader: &mut Cursor<&[u8]>| reader.read_i64::() ); impl TupleValueSerializable for TimeSerializable { - fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + fn to_raw(&self, value: &DataValue, writer: &mut W) -> Result<(), DatabaseError> { let DataValue::Time32(v, ..) = value else { unsafe { std::hint::unreachable_unchecked() } }; @@ -481,7 +494,7 @@ impl TupleValueSerializable for TimeSerializable { } impl TupleValueSerializable for TimeStampSerializable { - fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + fn to_raw(&self, value: &DataValue, writer: &mut W) -> Result<(), DatabaseError> { let DataValue::Time64(v, ..) = value else { unsafe { std::hint::unreachable_unchecked() } }; @@ -499,10 +512,11 @@ impl TupleValueSerializable for TimeStampSerializable { } } +#[cfg(feature = "decimal")] impl_tuple_value_serializable!( DecimalSerializable, Decimal, - |writer: &mut Vec, &value: &Decimal| writer.write_all(&value.serialize()), + |writer: &mut dyn Write, &value: &Decimal| writer.write_all(&value.serialize()), |reader: &mut Cursor<&[u8]>| { let mut bytes = [0u8; 16]; reader.read_exact(&mut bytes)?; @@ -510,8 +524,24 @@ impl_tuple_value_serializable!( } ); +#[cfg(not(feature = "decimal"))] +impl TupleValueSerializable for DecimalSerializable { + fn to_raw(&self, _: &DataValue, _: &mut W) -> Result<(), DatabaseError> { + Err(DatabaseError::UnsupportedStmt( + "DECIMAL requires the `decimal` feature".to_string(), + )) + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + reader.seek(SeekFrom::Current(16))?; + Err(DatabaseError::UnsupportedStmt( + "DECIMAL requires the `decimal` feature".to_string(), + )) + } +} + impl TupleValueSerializable for SkipFixed { - fn to_raw(&self, _: &DataValue, _: &mut Vec) -> Result<(), DatabaseError> { + fn to_raw(&self, _: &DataValue, _: &mut W) -> Result<(), DatabaseError> { unreachable!(); } @@ -531,7 +561,7 @@ impl TupleValueSerializable for SkipFixed { } impl TupleValueSerializable for SkipVariable { - fn to_raw(&self, _: &DataValue, _: &mut Vec) -> Result<(), DatabaseError> { + fn to_raw(&self, _: &DataValue, _: &mut W) -> Result<(), DatabaseError> { unreachable!(); } diff --git a/src/types/tuple.rs b/src/types/tuple.rs index ae2822d5..f27db8d5 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -12,23 +12,68 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::ColumnRef; -use crate::db::ResultIter; +use crate::catalog::{ColumnCatalog, ColumnRef}; use crate::errors::DatabaseError; -use crate::storage::table_codec::BumpBytes; +use crate::planner::PlanArena; use crate::types::serialize::{TupleValueSerializable, TupleValueSerializableImpl}; use crate::types::value::DataValue; -use bumpalo::Bump; -use comfy_table::{Cell, Table}; use itertools::Itertools; +use std::borrow::Borrow; use std::io::Cursor; -use std::sync::Arc; const BITS_MAX_INDEX: usize = 8; pub type TupleId = DataValue; pub type Schema = Vec; -pub type SchemaRef = Arc; + +pub struct SchemaView<'a, 'p> { + schema: &'a Schema, + arena: &'a PlanArena<'p>, +} + +pub struct SchemaColumnIter<'a, 'p, 's> { + columns: std::slice::Iter<'s, ColumnRef>, + arena: &'a PlanArena<'p>, +} + +impl<'a> Iterator for SchemaColumnIter<'a, '_, '_> { + type Item = &'a ColumnCatalog; + + fn next(&mut self) -> Option { + self.columns.next().map(|column| self.arena.column(*column)) + } +} + +impl<'a, 'p> SchemaView<'a, 'p> { + pub fn new(schema: &'a Schema, arena: &'a PlanArena<'p>) -> Self { + Self { schema, arena } + } + + pub fn iter(&self) -> SchemaColumnIter<'a, 'p, '_> { + SchemaColumnIter { + columns: self.schema.iter(), + arena: self.arena, + } + } + + pub fn len(&self) -> usize { + self.schema.len() + } + + pub fn is_empty(&self) -> bool { + self.schema.is_empty() + } + + pub fn get(&self, index: usize) -> Option<&'a ColumnCatalog> { + self.schema + .get(index) + .map(|column| self.arena.column(*column)) + } + + pub fn position(&self, name: &str) -> Option { + self.iter().position(|column| column.name() == name) + } +} pub trait TupleLike { fn value_at(&self, index: usize) -> &DataValue; @@ -149,62 +194,73 @@ impl Tuple { } #[inline] - pub fn deserialize_from_into( + pub fn deserialize_from_into( &mut self, - deserializers: &[TupleValueSerializableImpl], + deserializers: I, bytes: &[u8], total_len: usize, - ) -> Result<(), DatabaseError> { + ) -> Result<(), DatabaseError> + where + I: IntoIterator, + S: Borrow, + { fn is_null(bits: u8, i: usize) -> bool { bits & (1 << (7 - i)) > 0 } let bits_len = (total_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; self.values.clear(); - self.values.reserve(deserializers.len()); + self.values.reserve(total_len); let mut cursor = Cursor::new(&bytes[bits_len..]); - for (i, deserializer) in deserializers.iter().enumerate() { + for (i, deserializer) in deserializers.into_iter().enumerate() { if is_null(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { self.values.push(DataValue::Null); continue; } - deserializer.filling_value(&mut cursor, &mut self.values)?; + deserializer + .borrow() + .filling_value(&mut cursor, &mut self.values)?; } Ok(()) } /// e.g.: bits(u8)..|data_0(len for utf8_1)|utf8_0|data_1| /// Tips: all len is u32 - pub fn serialize_to<'a>( + pub fn serialize_to( &self, - serializers: &[TupleValueSerializableImpl], - arena: &'a Bump, - ) -> Result, DatabaseError> { - debug_assert_eq!(self.values.len(), serializers.len()); - + serializers: I, + bytes: &mut Vec, + ) -> Result<(), DatabaseError> + where + I: IntoIterator, + S: Borrow, + { fn flip_bit(bits: u8, i: usize) -> u8 { bits | (1 << (7 - i)) } let values_len = self.values.len(); let bits_len = (values_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; - let mut bytes = BumpBytes::new_in(arena); + let values_bytes_len = self + .values + .iter() + .map(DataValue::serialized_len_hint) + .sum::(); + bytes.clear(); + bytes.reserve(bits_len + values_bytes_len); bytes.resize(bits_len, 0u8); - let null_bytes: *mut BumpBytes = &mut bytes; - debug_assert_eq!(self.values.len(), serializers.len()); - for (i, (value, serializer)) in self.values.iter().zip(serializers.iter()).enumerate() { + for (i, (value, serializer)) in self.values.iter().zip(serializers).enumerate() { if value.is_null() { - let null_bytes = unsafe { &mut *null_bytes }; - null_bytes[i / BITS_MAX_INDEX] = - flip_bit(null_bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX); + bytes[i / BITS_MAX_INDEX] = flip_bit(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX); } else { - serializer.to_raw(value, &mut bytes)?; + serializer.borrow().to_raw(value, bytes)?; } } - Ok(bytes) + + Ok(()) } pub fn primary_projection(pk_indices: &[usize], values: &[DataValue]) -> TupleId { @@ -219,40 +275,13 @@ impl Tuple { } } -pub fn create_table(iter: I) -> Result { - let mut table = Table::new(); - let mut header = Vec::new(); - let schema = iter.schema().clone(); - - for col in schema.iter() { - header.push(Cell::new(col.full_name())); - } - table.set_header(header); - - for tuple in iter { - let tuple = tuple?; - debug_assert_eq!(schema.len(), tuple.values.len()); - - let cells = tuple - .values - .iter() - .map(|value| Cell::new(format!("{value}"))) - .collect_vec(); - - table.add_row(cells); - } - - Ok(table) -} - #[cfg(all(test, not(target_arch = "wasm32")))] mod tests { - use crate::catalog::{ColumnCatalog, ColumnDesc, ColumnRef}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::types::tuple::Tuple; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; use crate::types::LogicalType; - use bumpalo::Bump; use itertools::Itertools; use ordered_float::OrderedFloat; use rust_decimal::Decimal; @@ -261,17 +290,17 @@ mod tests { #[test] fn test_tuple_serialize_to_and_deserialize_from() { let columns = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new( + ColumnCatalog::new( "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new(LogicalType::UInteger, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c3".to_string(), false, ColumnDesc::new( @@ -281,58 +310,58 @@ mod tests { None, ) .unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c4".to_string(), false, ColumnDesc::new(LogicalType::Smallint, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c5".to_string(), false, ColumnDesc::new(LogicalType::USmallint, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c6".to_string(), false, ColumnDesc::new(LogicalType::Float, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c7".to_string(), false, ColumnDesc::new(LogicalType::Double, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c8".to_string(), false, ColumnDesc::new(LogicalType::Tinyint, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c9".to_string(), false, ColumnDesc::new(LogicalType::UTinyint, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c10".to_string(), false, ColumnDesc::new(LogicalType::Boolean, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c11".to_string(), false, ColumnDesc::new(LogicalType::DateTime, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c12".to_string(), false, ColumnDesc::new(LogicalType::Date, None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c13".to_string(), false, ColumnDesc::new(LogicalType::Decimal(None, None), None, false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c14".to_string(), false, ColumnDesc::new( @@ -342,8 +371,8 @@ mod tests { None, ) .unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c15".to_string(), false, ColumnDesc::new( @@ -353,8 +382,8 @@ mod tests { None, ) .unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( + ), + ColumnCatalog::new( "c16".to_string(), false, ColumnDesc::new( @@ -364,7 +393,18 @@ mod tests { None, ) .unwrap(), - )), + ), + ColumnCatalog::new( + "c17".to_string(), + false, + ColumnDesc::new( + LogicalType::Char(3, CharLengthUnits::Characters), + None, + false, + None, + ) + .unwrap(), + ), ]); let tuples = [ @@ -403,6 +443,11 @@ mod tests { ty: Utf8Type::Fixed(10), unit: CharLengthUnits::Octets, }, + DataValue::Utf8 { + value: "你".to_string(), + ty: Utf8Type::Fixed(3), + unit: CharLengthUnits::Characters, + }, ], ), Tuple::new( @@ -424,6 +469,7 @@ mod tests { DataValue::Null, DataValue::Null, DataValue::Null, + DataValue::Null, ], ), ]; @@ -431,19 +477,15 @@ mod tests { .iter() .map(|column| column.datatype().serializable()) .collect_vec(); - let columns = Arc::new(columns); - let arena = Bump::new(); + let mut bytes = Vec::new(); { let mut tuple_0 = Tuple { pk: tuples[0].pk.clone(), values: Vec::with_capacity(serializers.len()), }; + tuples[0].serialize_to(&serializers, &mut bytes).unwrap(); tuple_0 - .deserialize_from_into( - &serializers, - &tuples[0].serialize_to(&serializers, &arena).unwrap(), - columns.len(), - ) + .deserialize_from_into(&serializers, &bytes, columns.len()) .unwrap(); assert_eq!(tuples[0], tuple_0); @@ -453,12 +495,9 @@ mod tests { pk: tuples[1].pk.clone(), values: Vec::with_capacity(serializers.len()), }; + tuples[1].serialize_to(&serializers, &mut bytes).unwrap(); tuple_1 - .deserialize_from_into( - &serializers, - &tuples[1].serialize_to(&serializers, &arena).unwrap(), - columns.len(), - ) + .deserialize_from_into(&serializers, &bytes, columns.len()) .unwrap(); assert_eq!(tuples[1], tuple_1); @@ -475,12 +514,9 @@ mod tests { pk: tuples[0].pk.clone(), values: Vec::with_capacity(2), }; + tuples[0].serialize_to(&serializers, &mut bytes).unwrap(); tuple_2 - .deserialize_from_into( - &projection_serializers, - &tuples[0].serialize_to(&serializers, &arena).unwrap(), - columns.len(), - ) + .deserialize_from_into(&projection_serializers, &bytes, columns.len()) .unwrap(); assert_eq!( @@ -511,12 +547,11 @@ mod tests { pk: multi_pk_tuple.pk.clone(), values: Vec::with_capacity(serializers.len()), }; + multi_pk_tuple + .serialize_to(&serializers, &mut bytes) + .unwrap(); tuple_3 - .deserialize_from_into( - &multiple_pk_serializers, - &multi_pk_tuple.serialize_to(&serializers, &arena).unwrap(), - columns.len(), - ) + .deserialize_from_into(&multiple_pk_serializers, &bytes, columns.len()) .unwrap(); assert_eq!( diff --git a/src/types/tuple_builder.rs b/src/types/tuple_builder.rs index f25d03c2..b5fac8e5 100644 --- a/src/types/tuple_builder.rs +++ b/src/types/tuple_builder.rs @@ -12,20 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::catalog::PrimaryKeyIndices; use crate::errors::DatabaseError; -use crate::types::tuple::{Schema, Tuple}; +use crate::types::tuple::Tuple; use crate::types::value::{DataValue, Utf8Type}; use crate::types::CharLengthUnits; +use crate::types::LogicalType; pub struct TupleBuilder<'a> { - schema: &'a Schema, - pk_indices: Option<&'a PrimaryKeyIndices>, + column_types: Vec, + pk_indices: Option<&'a [usize]>, } impl<'a> TupleBuilder<'a> { - pub fn new(schema: &'a Schema, pk_indices: Option<&'a PrimaryKeyIndices>) -> Self { - TupleBuilder { schema, pk_indices } + pub fn new(column_types: Vec, pk_indices: Option<&'a [usize]>) -> Self { + TupleBuilder { + column_types, + pk_indices, + } } pub fn build_result(message: String) -> Tuple { @@ -52,7 +55,7 @@ impl<'a> TupleBuilder<'a> { &self, row: impl IntoIterator, ) -> Result { - let mut values = Vec::with_capacity(self.schema.len()); + let mut values = Vec::with_capacity(self.column_types.len()); for (i, value) in row.into_iter().enumerate() { values.push( @@ -61,10 +64,10 @@ impl<'a> TupleBuilder<'a> { ty: Utf8Type::Variable(None), unit: CharLengthUnits::Characters, } - .cast(self.schema[i].datatype())?, + .cast(&self.column_types[i])?, ); } - if values.len() != self.schema.len() { + if values.len() != self.column_types.len() { return Err(DatabaseError::MisMatch("types", "values")); } diff --git a/src/types/value.rs b/src/types/value.rs index 9c3d613a..a231a5b4 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -15,36 +15,56 @@ use super::LogicalType; use crate::errors::DatabaseError; use crate::storage::table_codec::{BumpBytes, BOUND_MAX_TAG, NOTNULL_TAG, NULL_TAG}; -use crate::types::evaluator::cast_create; +use crate::types::evaluator::cast::{cast_create, to_char, to_varchar}; use crate::types::CharLengthUnits; use byteorder::ReadBytesExt; -use chrono::format::{DelayedFormat, StrftimeItems}; -use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; +#[cfg(feature = "time")] +use chrono::{ + format::{DelayedFormat, StrftimeItems}, + DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc, +}; use itertools::Itertools; use ordered_float::OrderedFloat; +#[cfg(feature = "decimal")] use rust_decimal::Decimal; use std::borrow::Cow; use std::cmp::Ordering; use std::fmt::Formatter; use std::hash::Hash; -use std::io::{Read, Write}; -use std::sync::LazyLock; -use std::{cmp, fmt, mem}; - -static UNIX_DATETIME: LazyLock = - LazyLock::new(|| DateTime::from_timestamp(0, 0).unwrap().naive_utc()); - -static UNIX_TIME: LazyLock = LazyLock::new(|| NaiveTime::from_hms_opt(0, 0, 0).unwrap()); +use std::io::Read; +#[cfg(feature = "decimal")] +use std::mem; +use std::{cmp, fmt}; + +#[cfg(feature = "time")] +mod chrono_value { + use chrono::{DateTime, NaiveDateTime, NaiveTime}; + use std::sync::LazyLock; + + pub(super) static UNIX_DATETIME: LazyLock = + LazyLock::new(|| DateTime::from_timestamp(0, 0).unwrap().naive_utc()); + pub(super) static UNIX_TIME: LazyLock = + LazyLock::new(|| NaiveTime::from_hms_opt(0, 0, 0).unwrap()); + + pub const DATE_FMT: &str = "%Y-%m-%d"; + pub const DATE_TIME_FMT: &str = "%Y-%m-%d %H:%M:%S"; + pub const TIME_STAMP_FMT_WITHOUT_ZONE: &str = "%Y-%m-%d %H:%M:%S%.f"; + pub const TIME_STAMP_FMT_WITH_ZONE: &str = "%Y-%m-%d %H:%M:%S%.f%z"; + pub const TIME_STAMP_FMT_WITHOUT_PRECISION: &str = "%Y-%m-%d %H:%M:%S%z"; + pub const TIME_FMT: &str = "%H:%M:%S"; + pub const TIME_FMT_WITHOUT_ZONE: &str = "%H:%M:%S%.f"; + pub const TIME_FMT_WITH_ZONE: &str = "%H:%M:%S%.f%z"; + pub const TIME_FMT_WITHOUT_PRECISION: &str = "%H:%M:%S%z"; +} -pub const DATE_FMT: &str = "%Y-%m-%d"; -pub const DATE_TIME_FMT: &str = "%Y-%m-%d %H:%M:%S"; -pub const TIME_STAMP_FMT_WITHOUT_ZONE: &str = "%Y-%m-%d %H:%M:%S%.f"; -pub const TIME_STAMP_FMT_WITH_ZONE: &str = "%Y-%m-%d %H:%M:%S%.f%z"; -pub const TIME_STAMP_FMT_WITHOUT_PRECISION: &str = "%Y-%m-%d %H:%M:%S%z"; -pub const TIME_FMT: &str = "%H:%M:%S"; -pub const TIME_FMT_WITHOUT_ZONE: &str = "%H:%M:%S%.f"; -pub const TIME_FMT_WITH_ZONE: &str = "%H:%M:%S%.f%z"; -pub const TIME_FMT_WITHOUT_PRECISION: &str = "%H:%M:%S%z"; +#[cfg(feature = "time")] +pub use chrono_value::{ + DATE_FMT, DATE_TIME_FMT, TIME_FMT, TIME_FMT_WITHOUT_PRECISION, TIME_FMT_WITHOUT_ZONE, + TIME_FMT_WITH_ZONE, TIME_STAMP_FMT_WITHOUT_PRECISION, TIME_STAMP_FMT_WITHOUT_ZONE, + TIME_STAMP_FMT_WITH_ZONE, +}; +#[cfg(feature = "time")] +use chrono_value::{UNIX_DATETIME, UNIX_TIME}; pub const ONE_SEC_TO_NANO: u32 = 1_000_000_000; pub const ONE_DAY_TO_SEC: u32 = 86_400; @@ -52,7 +72,7 @@ pub const ONE_DAY_TO_SEC: u32 = 86_400; const ENCODE_GROUP_SIZE: usize = 8; const ENCODE_MARKER: u8 = 0xFF; -pub trait MemComparableBuffer: Write { +pub trait MemComparableBuffer { fn push_byte(&mut self, byte: u8); fn extend_bytes(&mut self, bytes: &[u8]); fn reserve_bytes(&mut self, size: usize); @@ -96,13 +116,13 @@ impl MemComparableBuffer for Vec { } } -#[derive(Clone, serde::Serialize, serde::Deserialize)] +#[derive(Clone)] pub enum Utf8Type { Variable(Option), Fixed(u32), } -#[derive(Clone, serde::Serialize, serde::Deserialize)] +#[derive(Clone)] pub enum DataValue { Null, Boolean(bool), @@ -127,6 +147,7 @@ pub enum DataValue { Date64(i64), Time32(u32, u64), Time64(i64, u64, bool), + #[cfg(feature = "decimal")] Decimal(Decimal), /// (values, is_upper) Tuple(Vec, bool), @@ -227,10 +248,20 @@ generate_get_option!(DataValue, u8 : UInt8(Option), u16 : UInt16(Option), u32 : UInt32(Option), - u64 : UInt64(Option), - decimal : Decimal(Option) + u64 : UInt64(Option) ); +#[cfg(feature = "decimal")] +impl DataValue { + pub fn decimal(&self) -> Option { + if let DataValue::Decimal(val) = self { + Some(*val) + } else { + None + } + } +} + impl PartialEq for DataValue { fn eq(&self, other: &Self) -> bool { use DataValue::*; @@ -274,7 +305,9 @@ impl PartialEq for DataValue { (Time32(..), _) => false, (Time64(v1, ..), Time64(v2, ..)) => v1.eq(v2), (Time64(..), _) => false, + #[cfg(feature = "decimal")] (Decimal(v1), Decimal(v2)) => v1.eq(v2), + #[cfg(feature = "decimal")] (Decimal(_), _) => false, (Tuple(values_1, is_upper_1), Tuple(values_2, is_upper_2)) => { values_1.eq(values_2) && is_upper_1.eq(is_upper_2) @@ -322,7 +355,9 @@ impl PartialOrd for DataValue { (Time32(..), _) => None, (Time64(v1, ..), Time64(v2, ..)) => v1.partial_cmp(v2), (Time64(..), _) => None, + #[cfg(feature = "decimal")] (Decimal(v1), Decimal(v2)) => v1.partial_cmp(v2), + #[cfg(feature = "decimal")] (Decimal(_), _) => None, (Tuple(..), _) => None, } @@ -331,7 +366,7 @@ impl PartialOrd for DataValue { macro_rules! encode_u { ($writer:ident, $u:expr) => { - $writer.write_all(&$u.to_be_bytes())? + $writer.extend_bytes(&$u.to_be_bytes()) }; } @@ -366,6 +401,7 @@ impl Hash for DataValue { Date64(v) => v.hash(state), Time32(v, ..) => v.hash(state), Time64(v, ..) => v.hash(state), + #[cfg(feature = "decimal")] Decimal(v) => v.hash(state), Tuple(values, is_upper) => { values.hash(state); @@ -375,6 +411,35 @@ impl Hash for DataValue { } } impl DataValue { + pub(crate) fn serialized_len_hint(&self) -> usize { + match self { + DataValue::Null => 0, + DataValue::Boolean(_) | DataValue::Int8(_) | DataValue::UInt8(_) => 1, + DataValue::Int16(_) | DataValue::UInt16(_) => 2, + DataValue::Int32(_) + | DataValue::UInt32(_) + | DataValue::Float32(_) + | DataValue::Date32(_) + | DataValue::Time32(_, _) => 4, + DataValue::Int64(_) + | DataValue::UInt64(_) + | DataValue::Float64(_) + | DataValue::Date64(_) + | DataValue::Time64(_, _, _) => 8, + DataValue::Utf8 { value, ty, unit } => match (ty, unit) { + (Utf8Type::Variable(_), _) => std::mem::size_of::() + value.len(), + (Utf8Type::Fixed(len), CharLengthUnits::Characters) => { + let spaces_len = (*len as usize).saturating_sub(value.chars().count()); + std::mem::size_of::() + value.len() + spaces_len + } + (Utf8Type::Fixed(len), CharLengthUnits::Octets) => *len as usize, + }, + #[cfg(feature = "decimal")] + DataValue::Decimal(_) => 16, + DataValue::Tuple(values, _) => values.iter().map(DataValue::serialized_len_hint).sum(), + } + } + pub fn float(&self) -> Option { if let DataValue::Float32(val) = self { Some(val.0) @@ -407,30 +472,6 @@ impl DataValue { } } - pub fn date(&self) -> Option { - if let DataValue::Date32(val) = self { - NaiveDate::from_num_days_from_ce_opt(*val) - } else { - None - } - } - - pub fn datetime(&self) -> Option { - if let DataValue::Date64(val) = self { - DateTime::from_timestamp(*val, 0).map(|dt| dt.naive_utc()) - } else { - None - } - } - - pub fn time(&self) -> Option { - if let DataValue::Time32(val, ..) = self { - NaiveTime::from_num_seconds_from_midnight_opt(*val, 0) - } else { - None - } - } - #[inline] pub(crate) fn check_string_len(string: &str, len: usize, unit: CharLengthUnits) -> bool { match unit { @@ -475,6 +516,7 @@ impl DataValue { unit: CharLengthUnits::Octets, }, ) => Self::check_string_len(val, *len as usize, CharLengthUnits::Octets), + #[cfg(feature = "decimal")] (LogicalType::Decimal(full_len, scale_len), DataValue::Decimal(val)) => { if let Some(len) = full_len { let mantissa = val.mantissa().abs(); @@ -527,22 +569,6 @@ impl DataValue { (b, scaled_a * (1000000000 / 10_u32.pow(precision as u32))) } - pub(crate) fn format_date(value: i32) -> Option { - Self::date_format(value).map(|fmt| format!("{fmt}")) - } - - pub(crate) fn format_datetime(value: i64) -> Option { - Self::date_time_format(value).map(|fmt| format!("{fmt}")) - } - - pub(crate) fn format_time(value: u32, precision: u64) -> Option { - Self::time_format(value, precision).map(|fmt| format!("{fmt}")) - } - - pub(crate) fn format_timestamp(value: i64, precision: u64) -> Option { - Self::time_stamp_format(value, precision, false).map(|fmt| format!("{fmt}")) - } - #[inline] pub fn is_null(&self) -> bool { matches!(self, DataValue::Null) @@ -573,26 +599,66 @@ impl DataValue { ty: Utf8Type::Variable(*len), unit: *unit, }, - LogicalType::Date => DataValue::Date32(UNIX_DATETIME.num_days_from_ce()), - LogicalType::DateTime => DataValue::Date64(UNIX_DATETIME.and_utc().timestamp()), + LogicalType::Date => { + #[cfg(feature = "time")] + { + DataValue::Date32(UNIX_DATETIME.num_days_from_ce()) + } + #[cfg(not(feature = "time"))] + { + DataValue::Date32(0) + } + } + LogicalType::DateTime => { + #[cfg(feature = "time")] + { + DataValue::Date64(UNIX_DATETIME.and_utc().timestamp()) + } + #[cfg(not(feature = "time"))] + { + DataValue::Date64(0) + } + } LogicalType::Time(precision) => match precision { + #[cfg(feature = "time")] Some(i) => DataValue::Time32(UNIX_TIME.num_seconds_from_midnight(), *i), + #[cfg(feature = "time")] None => DataValue::Time32(UNIX_TIME.num_seconds_from_midnight(), 0), + #[cfg(not(feature = "time"))] + Some(i) => DataValue::Time32(0, *i), + #[cfg(not(feature = "time"))] + None => DataValue::Time32(0, 0), }, - LogicalType::TimeStamp(precision, zone) => match precision { - Some(3) => DataValue::Time64(UNIX_DATETIME.and_utc().timestamp_millis(), 3, *zone), - Some(6) => DataValue::Time64(UNIX_DATETIME.and_utc().timestamp_micros(), 6, *zone), - Some(9) => { - if let Some(value) = UNIX_DATETIME.and_utc().timestamp_nanos_opt() { - DataValue::Time64(value, 9, *zone) - } else { - unreachable!() + LogicalType::TimeStamp(precision, zone) => { + #[cfg(feature = "time")] + { + match precision { + Some(3) => { + DataValue::Time64(UNIX_DATETIME.and_utc().timestamp_millis(), 3, *zone) + } + Some(6) => { + DataValue::Time64(UNIX_DATETIME.and_utc().timestamp_micros(), 6, *zone) + } + Some(9) => { + if let Some(value) = UNIX_DATETIME.and_utc().timestamp_nanos_opt() { + DataValue::Time64(value, 9, *zone) + } else { + unreachable!() + } + } + None => DataValue::Time64(UNIX_DATETIME.and_utc().timestamp(), 0, *zone), + _ => unreachable!(), } } - None => DataValue::Time64(UNIX_DATETIME.and_utc().timestamp(), 0, *zone), - _ => unreachable!(), - }, + #[cfg(not(feature = "time"))] + { + DataValue::Time64(0, precision.unwrap_or_default(), *zone) + } + } + #[cfg(feature = "decimal")] LogicalType::Decimal(_, _) => DataValue::Decimal(Decimal::new(0, 0)), + #[cfg(not(feature = "decimal"))] + LogicalType::Decimal(_, _) => unreachable!("DECIMAL requires the `decimal` feature"), LogicalType::Tuple(types) => { let values = types.iter().map(DataValue::init).collect_vec(); @@ -630,6 +696,7 @@ impl DataValue { DataValue::Date64(_) => LogicalType::DateTime, DataValue::Time32(..) => LogicalType::Time(None), DataValue::Time64(..) => LogicalType::TimeStamp(None, false), + #[cfg(feature = "decimal")] DataValue::Decimal(_) => LogicalType::Decimal(None, None), DataValue::Tuple(values, ..) => { let types = values.iter().map(|v| v.logical_type()).collect_vec(); @@ -745,15 +812,16 @@ impl DataValue { DataValue::Null => (), DataValue::Int8(v) => encode_u!(b, *v as u8 ^ 0x80_u8), DataValue::Int16(v) => encode_u!(b, *v as u16 ^ 0x8000_u16), - DataValue::Int32(v) | DataValue::Date32(v) => { - encode_u!(b, *v as u32 ^ 0x80000000_u32) - } - DataValue::Int64(v) | DataValue::Date64(v) | DataValue::Time64(v, ..) => { + DataValue::Int32(v) => encode_u!(b, *v as u32 ^ 0x80000000_u32), + DataValue::Date32(v) => encode_u!(b, *v as u32 ^ 0x80000000_u32), + DataValue::Int64(v) => encode_u!(b, *v as u64 ^ 0x8000000000000000_u64), + DataValue::Date64(v) | DataValue::Time64(v, ..) => { encode_u!(b, *v as u64 ^ 0x8000000000000000_u64) } DataValue::UInt8(v) => encode_u!(b, v), DataValue::UInt16(v) => encode_u!(b, v), - DataValue::UInt32(v) | DataValue::Time32(v, ..) => encode_u!(b, v), + DataValue::UInt32(v) => encode_u!(b, v), + DataValue::Time32(v, ..) => encode_u!(b, v), DataValue::UInt64(v) => encode_u!(b, v), DataValue::Utf8 { value: v, .. } => Self::encode_string(b, v.as_bytes()), DataValue::Boolean(v) => b.push_byte(if *v { b'1' } else { b'0' }), @@ -779,6 +847,7 @@ impl DataValue { encode_u!(b, u); } + #[cfg(feature = "decimal")] DataValue::Decimal(v) => Self::serialize_decimal(*v, b)?, DataValue::Tuple(values, is_upper) => { let last = values.len() - 1; @@ -830,14 +899,34 @@ impl DataValue { let u = decode_u!(reader, u16); Ok(DataValue::Int16((u ^ 0x8000) as i16)) } - LogicalType::Integer | LogicalType::Date | LogicalType::Time(_) => { + LogicalType::Integer => { let u = decode_u!(reader, u32); Ok(DataValue::Int32((u ^ 0x8000_0000) as i32)) } - LogicalType::Bigint | LogicalType::DateTime | LogicalType::TimeStamp(..) => { + LogicalType::Date => { + let u = decode_u!(reader, u32); + Ok(DataValue::Date32((u ^ 0x8000_0000) as i32)) + } + LogicalType::Time(precision) => { + let u = decode_u!(reader, u32); + Ok(DataValue::Time32(u, precision.unwrap_or_default())) + } + LogicalType::Bigint => { let u = decode_u!(reader, u64); Ok(DataValue::Int64((u ^ 0x8000_0000_0000_0000) as i64)) } + LogicalType::DateTime => { + let u = decode_u!(reader, u64); + Ok(DataValue::Date64((u ^ 0x8000_0000_0000_0000) as i64)) + } + LogicalType::TimeStamp(precision, zone) => { + let u = decode_u!(reader, u64); + Ok(DataValue::Time64( + (u ^ 0x8000_0000_0000_0000) as i64, + precision.unwrap_or_default(), + *zone, + )) + } LogicalType::UTinyint => Ok(DataValue::UInt8(decode_u!(reader, u8))), LogicalType::USmallint => Ok(DataValue::UInt16(decode_u!(reader, u16))), LogicalType::UInteger => Ok(DataValue::UInt32(decode_u!(reader, u32))), @@ -880,7 +969,12 @@ impl DataValue { ty: Utf8Type::Fixed(*len), unit: *unit, }), + #[cfg(feature = "decimal")] LogicalType::Decimal(..) => Ok(DataValue::Decimal(Self::deserialize_decimal(reader)?)), + #[cfg(not(feature = "decimal"))] + LogicalType::Decimal(..) => Err(DatabaseError::UnsupportedStmt( + "DECIMAL requires the `decimal` feature".to_string(), + )), LogicalType::Tuple(tys) => { let mut collector = TupleCollector::new(tuple_mapping, tys.len()); @@ -894,6 +988,7 @@ impl DataValue { } // https://github.com/risingwavelabs/memcomparable/blob/main/src/ser.rs#L468 + #[cfg(feature = "decimal")] pub fn serialize_decimal( decimal: Decimal, bytes: &mut B, @@ -939,6 +1034,7 @@ impl DataValue { Ok(()) } + #[cfg(feature = "decimal")] fn decimal_e_m(decimal: Decimal) -> (i8, Vec) { if decimal.is_zero() { return (0, vec![]); @@ -1015,6 +1111,7 @@ impl DataValue { (e100 as i8, byte_array) } + #[cfg(feature = "decimal")] pub fn deserialize_decimal(mut reader: R) -> Result { // decode exponent let flag = reader.read_u8()?; @@ -1090,9 +1187,20 @@ impl DataValue { if &from == to { return Ok(self); } - let evaluator = cast_create(Cow::Owned(from), Cow::Borrowed(to))?; - evaluator.eval_cast(&self) + match (self, to) { + (DataValue::Null, _) => Ok(DataValue::Null), + (DataValue::Utf8 { value, .. }, LogicalType::Char(len, unit)) => { + to_char(value, *len, *unit) + } + (DataValue::Utf8 { value, .. }, LogicalType::Varchar(len, unit)) => { + to_varchar(value, *len, *unit) + } + (value, _) => { + let evaluator = cast_create(Cow::Owned(from), Cow::Borrowed(to))?; + evaluator.eval(&value) + } + } } #[inline] @@ -1123,15 +1231,7 @@ impl DataValue { Some(0) } - #[inline] - pub(crate) fn values_to_tuple(mut values: Vec) -> Option { - if values.len() > 1 { - Some(DataValue::Tuple(values, false)) - } else { - values.pop() - } - } - + #[cfg(feature = "decimal")] pub(crate) fn decimal_round_i(option: &Option, decimal: &mut Decimal) { if let Some(scale) = option { let new_decimal = decimal.trunc_with_scale(*scale as u32); @@ -1139,6 +1239,7 @@ impl DataValue { } } + #[cfg(feature = "decimal")] pub(crate) fn decimal_round_f(option: &Option, decimal: &mut Decimal) { if let Some(scale) = option { let new_decimal = decimal.round_dp_with_strategy( @@ -1149,6 +1250,54 @@ impl DataValue { } } + #[cfg(feature = "decimal")] + fn decimal_format(v: &Decimal) -> String { + v.to_string() + } +} + +#[cfg(feature = "time")] +impl DataValue { + pub fn date(&self) -> Option { + if let DataValue::Date32(val) = self { + NaiveDate::from_num_days_from_ce_opt(*val) + } else { + None + } + } + + pub fn datetime(&self) -> Option { + if let DataValue::Date64(val) = self { + DateTime::from_timestamp(*val, 0).map(|dt| dt.naive_utc()) + } else { + None + } + } + + pub fn time(&self) -> Option { + if let DataValue::Time32(val, ..) = self { + NaiveTime::from_num_seconds_from_midnight_opt(*val, 0) + } else { + None + } + } + + pub(crate) fn format_date(value: i32) -> Option { + Self::date_format(value).map(|fmt| format!("{fmt}")) + } + + pub(crate) fn format_datetime(value: i64) -> Option { + Self::date_time_format(value).map(|fmt| format!("{fmt}")) + } + + pub(crate) fn format_time(value: u32, precision: u64) -> Option { + Self::time_format(value, precision).map(|fmt| format!("{fmt}")) + } + + pub(crate) fn format_timestamp(value: i64, precision: u64) -> Option { + Self::time_stamp_format(value, precision, false).map(|fmt| format!("{fmt}")) + } + fn date_format<'a>(v: i32) -> Option>> { NaiveDate::from_num_days_from_ce_opt(v).map(|date| date.format(DATE_FMT)) } @@ -1172,10 +1321,6 @@ impl DataValue { .map(|date_time| date_time.format(TIME_STAMP_FMT_WITHOUT_ZONE)) } - fn decimal_format(v: &Decimal) -> String { - v.to_string() - } - pub fn timestamp_precision(v: DateTime, precision: u64) -> i64 { match precision { 3 => v.timestamp_millis(), @@ -1230,6 +1375,7 @@ impl_scalar!(u8, UInt8); impl_scalar!(u16, UInt16); impl_scalar!(u32, UInt32); impl_scalar!(u64, UInt64); +#[cfg(feature = "decimal")] impl_scalar!(Decimal, Decimal); impl From for DataValue { @@ -1288,12 +1434,14 @@ impl From> for DataValue { } } +#[cfg(feature = "time")] impl From<&NaiveDate> for DataValue { fn from(value: &NaiveDate) -> Self { DataValue::Date32(value.num_days_from_ce()) } } +#[cfg(feature = "time")] impl From> for DataValue { fn from(value: Option<&NaiveDate>) -> Self { if let Some(value) = value { @@ -1304,12 +1452,14 @@ impl From> for DataValue { } } +#[cfg(feature = "time")] impl From<&NaiveDateTime> for DataValue { fn from(value: &NaiveDateTime) -> Self { DataValue::Date64(value.and_utc().timestamp()) } } +#[cfg(feature = "time")] impl From> for DataValue { fn from(value: Option<&NaiveDateTime>) -> Self { if let Some(value) = value { @@ -1320,6 +1470,7 @@ impl From> for DataValue { } } +#[cfg(feature = "time")] impl From<&NaiveTime> for DataValue { fn from(value: &NaiveTime) -> Self { DataValue::Time32( @@ -1329,6 +1480,7 @@ impl From<&NaiveTime> for DataValue { } } +#[cfg(feature = "time")] impl From> for DataValue { fn from(value: Option<&NaiveTime>) -> Self { if let Some(value) = value { @@ -1342,34 +1494,6 @@ impl From> for DataValue { } } -impl TryFrom<&sqlparser::ast::Value> for DataValue { - type Error = DatabaseError; - - fn try_from(value: &sqlparser::ast::Value) -> Result { - Ok(match value { - sqlparser::ast::Value::Number(n, _) => { - // use i32 to handle most cases - if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { - v.into() - } else { - return Err(DatabaseError::InvalidValue(n.to_string())); - } - } - sqlparser::ast::Value::SingleQuotedString(s) - | sqlparser::ast::Value::DoubleQuotedString(s) => s.clone().into(), - sqlparser::ast::Value::Boolean(b) => (*b).into(), - sqlparser::ast::Value::Null => Self::Null, - v => return Err(DatabaseError::UnsupportedStmt(format!("{v:?}"))), - }) - } -} - macro_rules! format_float_option { ($F:expr, $EXPR:expr) => {{ let formatted_string = format!("{:?}", $EXPR); @@ -1399,16 +1523,41 @@ impl fmt::Display for DataValue { DataValue::UInt64(e) => write!(f, "{e}")?, DataValue::Utf8 { value: e, .. } => write!(f, "{e}")?, DataValue::Null => write!(f, "null")?, - DataValue::Date32(e) => write!(f, "{}", DataValue::date_format(*e).unwrap())?, - DataValue::Date64(e) => write!(f, "{}", DataValue::date_time_format(*e).unwrap())?, + DataValue::Date32(e) => { + #[cfg(feature = "time")] + write!(f, "{}", DataValue::date_format(*e).unwrap())?; + #[cfg(not(feature = "time"))] + write!(f, "{e}")?; + } + DataValue::Date64(e) => { + #[cfg(feature = "time")] + write!(f, "{}", DataValue::date_time_format(*e).unwrap())?; + #[cfg(not(feature = "time"))] + write!(f, "{e}")?; + } DataValue::Time32(e, precision) => { - write!(f, "{}", DataValue::time_format(*e, *precision).unwrap())? + #[cfg(feature = "time")] + write!(f, "{}", DataValue::time_format(*e, *precision).unwrap())?; + #[cfg(not(feature = "time"))] + { + let _ = precision; + write!(f, "{e}")?; + } + } + DataValue::Time64(e, precision, zone) => { + #[cfg(feature = "time")] + write!( + f, + "{}", + DataValue::time_stamp_format(*e, *precision, *zone).unwrap() + )?; + #[cfg(not(feature = "time"))] + { + let _ = (precision, zone); + write!(f, "{e}")?; + } } - DataValue::Time64(e, precision, zone) => write!( - f, - "{}", - DataValue::time_stamp_format(*e, *precision, *zone).unwrap() - )?, + #[cfg(feature = "decimal")] DataValue::Decimal(e) => write!(f, "{}", DataValue::decimal_format(e))?, DataValue::Tuple(values, ..) => { write!(f, "(")?; @@ -1447,6 +1596,7 @@ impl fmt::Debug for DataValue { DataValue::Date64(_) => write!(f, "Date64({self})"), DataValue::Time32(..) => write!(f, "Time32({self})"), DataValue::Time64(..) => write!(f, "Time64({self})"), + #[cfg(feature = "decimal")] DataValue::Decimal(_) => write!(f, "Decimal({self})"), DataValue::Tuple(..) => { write!(f, "Tuple({self}")?; diff --git a/src/utils/lru.rs b/src/utils/lru.rs deleted file mode 100644 index bc8417dd..00000000 --- a/src/utils/lru.rs +++ /dev/null @@ -1,444 +0,0 @@ -// Copyright 2024 KipData/KiteSQL -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::errors::DatabaseError; -use parking_lot::Mutex; -use std::borrow::Borrow; -use std::cmp::Ordering; -use std::collections::hash_map::{Iter, RandomState}; -use std::collections::HashMap; -use std::hash::{BuildHasher, Hash, Hasher}; -use std::marker::PhantomData; -use std::ops::{Deref, DerefMut}; -use std::ptr::NonNull; - -// 只读Node操作裸指针 -// https://course.rs/advance/concurrency-with-threads/send-sync.html#:~:text=%E5%AE%89%E5%85%A8%E7%9A%84%E4%BD%BF%E7%94%A8%E3%80%82-,%E4%B8%BA%E8%A3%B8%E6%8C%87%E9%92%88%E5%AE%9E%E7%8E%B0Send,-%E4%B8%8A%E9%9D%A2%E6%88%91%E4%BB%AC%E6%8F%90%E5%88%B0 -// 通过只读数据已保证线程安全 -struct NodeReadPtr(NonNull>); - -unsafe impl Send for NodeReadPtr {} -unsafe impl Sync for NodeReadPtr {} - -impl Clone for NodeReadPtr { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for NodeReadPtr {} - -impl Deref for NodeReadPtr { - type Target = NonNull>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for NodeReadPtr { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -unsafe impl Send for SharedLruCache {} -unsafe impl Sync for SharedLruCache {} - -pub struct SharedLruCache { - shared_vec: Vec>>, - hasher: S, -} - -struct Node { - key: K, - value: V, - prev: Option>, - next: Option>, -} - -struct KeyRef(NodeReadPtr); - -impl Borrow for KeyRef { - fn borrow(&self) -> &K { - unsafe { &self.0.as_ref().key } - } -} - -impl Hash for KeyRef { - fn hash(&self, state: &mut H) { - unsafe { self.0.as_ref().key.hash(state) } - } -} - -impl Eq for KeyRef {} - -impl PartialEq for KeyRef { - #[allow(clippy::unconditional_recursion)] - fn eq(&self, other: &Self) -> bool { - unsafe { self.0.as_ref().key.eq(&other.0.as_ref().key) } - } -} - -impl PartialOrd for KeyRef { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for KeyRef { - fn cmp(&self, other: &Self) -> Ordering { - unsafe { self.0.as_ref().key.cmp(&other.0.as_ref().key) } - } -} - -/// LRU缓存 -/// 参考知乎中此文章的实现: -/// https://zhuanlan.zhihu.com/p/466409120 -pub struct LruCache { - head: Option>, - tail: Option>, - inner: HashMap, NodeReadPtr>, - cap: usize, - marker: PhantomData>, -} - -impl Node { - fn new(key: K, value: V) -> Self { - Self { - key, - value, - prev: None, - next: None, - } - } -} - -impl SharedLruCache { - #[inline] - pub fn new(cap: usize, shared_size: usize, hasher: S) -> Result { - let mut shared_vec = Vec::with_capacity(shared_size); - if !cap.is_multiple_of(shared_size) { - return Err(DatabaseError::SharedNotAlign); - } - let shared_cap = cap / shared_size; - for _ in 0..shared_size { - shared_vec.push(Mutex::new(LruCache::new(shared_cap)?)); - } - - Ok(SharedLruCache { shared_vec, hasher }) - } - - #[inline] - pub fn get(&self, key: &K) -> Option<&V> { - self.shard(key) - .lock() - .get_node(key) - .map(|node| unsafe { &node.as_ref().value }) - } - - #[inline] - pub fn put(&self, key: K, value: V) -> Option { - self.shard(&key).lock().put(key, value) - } - - #[inline] - pub fn remove(&self, key: &K) -> Option { - self.shard(key).lock().remove(key) - } - - #[inline] - pub fn is_empty(&self) -> bool { - for lru in &self.shared_vec { - if !lru.lock().is_empty() { - return false; - } - } - true - } - - #[inline] - pub fn get_or_insert(&self, key: K, fn_once: F) -> Result<&V, DatabaseError> - where - F: FnOnce(&K) -> Result, - { - self.shard(&key) - .lock() - .get_or_insert_node(key, fn_once) - .map(|node| unsafe { &node.as_ref().value }) - } - - fn shared_size(&self) -> usize { - self.shared_vec.len() - } - - /// 通过key获取hash值后对其求余获取对应分片 - fn shard(&self, key: &K) -> &Mutex> { - let mut hasher = self.hasher.build_hasher(); - key.hash(&mut hasher); - #[allow(clippy::manual_hash_one)] - &self.shared_vec[hasher.finish() as usize % self.shared_size()] - } -} - -impl LruCache { - #[inline] - pub fn new(cap: usize) -> Result { - if cap < 1 { - return Err(DatabaseError::CacheSizeOverFlow); - } - - Ok(Self { - head: None, - tail: None, - inner: HashMap::new(), - cap, - marker: PhantomData, - }) - } - - /// 移除节点 - fn detach(&mut self, mut node: NodeReadPtr) { - unsafe { - match node.as_mut().prev { - Some(mut prev) => { - prev.as_mut().next = node.as_ref().next; - } - None => { - self.head = node.as_ref().next; - } - } - match node.as_mut().next { - Some(mut next) => { - next.as_mut().prev = node.as_ref().prev; - } - None => { - self.tail = node.as_ref().prev; - } - } - - node.as_mut().prev = None; - node.as_mut().next = None; - } - } - - /// 添加节点至头部 - fn attach(&mut self, mut node: NodeReadPtr) { - match self.head { - Some(mut head) => { - unsafe { - head.as_mut().prev = Some(node); - node.as_mut().next = Some(head); - node.as_mut().prev = None; - } - self.head = Some(node); - } - None => { - unsafe { - node.as_mut().prev = None; - node.as_mut().next = None; - } - self.head = Some(node); - self.tail = Some(node); - } - } - } - - /// 判断并驱逐节点 - fn expulsion(&mut self) { - if let Some(tail) = self.tail { - if self.inner.len() >= self.cap { - self.detach(tail); - let _ignore = self.inner.remove(&KeyRef(tail)); - } - } - } - - #[inline] - #[allow(clippy::manual_inspect)] - pub fn put(&mut self, key: K, value: V) -> Option { - let node = NodeReadPtr(Box::leak(Box::new(Node::new(key, value))).into()); - let old_node = self.inner.remove(&KeyRef(node)).map(|node| { - self.detach(node); - node - }); - self.expulsion(); - self.attach(node); - let _ignore1 = self.inner.insert(KeyRef(node), node); - old_node.map(|node| unsafe { - let node: Box> = Box::from_raw(node.as_ptr()); - node.value - }) - } - - #[allow(dead_code)] - fn get_node(&mut self, key: &K) -> Option> { - if let Some(node) = self.inner.get(key) { - let node = *node; - self.detach(node); - self.attach(node); - Some(node) - } else { - None - } - } - - #[allow(dead_code)] - #[inline] - pub fn get(&mut self, key: &K) -> Option<&V> { - if let Some(node) = self.inner.get(key) { - let node = *node; - self.detach(node); - self.attach(node); - unsafe { Some(&node.as_ref().value) } - } else { - None - } - } - - #[inline] - pub fn remove(&mut self, key: &K) -> Option { - self.inner.remove(key).map(|node| { - self.detach(node); - unsafe { - let node: Box> = Box::from_raw(node.as_ptr()); - node.value - } - }) - } - - fn get_or_insert_node( - &mut self, - key: K, - fn_once: F, - ) -> Result, DatabaseError> - where - F: FnOnce(&K) -> Result, - { - if let Some(node) = self.inner.get(&key) { - let node = *node; - self.detach(node); - self.attach(node); - Ok(node) - } else { - let value = fn_once(&key)?; - let node = NodeReadPtr(Box::leak(Box::new(Node::new(key, value))).into()); - self.inner.remove(&KeyRef(node)).inspect(|&node| { - self.detach(node); - }); - self.expulsion(); - self.attach(node); - let _ignore1 = self.inner.insert(KeyRef(node), node); - Ok(node) - } - } - - #[allow(dead_code)] - #[inline] - pub fn get_or_insert(&mut self, key: K, fn_once: F) -> Result<&V, DatabaseError> - where - F: FnOnce(&K) -> Result, - { - self.get_or_insert_node(key, fn_once) - .map(|node| unsafe { &node.as_ref().value }) - } - - #[allow(dead_code)] - #[inline] - pub fn len(&self) -> usize { - self.inner.len() - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - #[allow(dead_code)] - #[inline] - pub fn iter(&self) -> LruCacheIter<'_, K, V> { - LruCacheIter { - inner: self.inner.iter(), - } - } -} - -pub struct LruCacheIter<'a, K, V> { - inner: Iter<'a, KeyRef, NodeReadPtr>, -} - -impl<'a, K, V> Iterator for LruCacheIter<'a, K, V> { - type Item = (&'a K, &'a V); - - #[inline] - fn next(&mut self) -> Option { - self.inner - .next() - .map(|(_, node)| unsafe { (&node.as_ref().key, &node.as_ref().value) }) - } -} - -impl Drop for LruCache { - #[inline] - fn drop(&mut self) { - while let Some(node) = self.head.take() { - unsafe { - self.head = node.as_ref().next; - drop(Box::from_raw(node.as_ptr())) - } - } - } -} - -#[cfg(all(test, not(target_arch = "wasm32")))] -mod tests { - use crate::utils::lru::{LruCache, SharedLruCache}; - use std::collections::hash_map::RandomState; - use std::collections::HashSet; - - #[test] - fn test_lru_cache() { - let mut lru = LruCache::new(3).unwrap(); - assert!(lru.is_empty()); - assert_eq!(lru.put(1, 10), None); - assert_eq!(lru.put(2, 20), None); - assert_eq!(lru.put(3, 30), None); - assert_eq!(lru.get(&1), Some(&10)); - assert_eq!(lru.put(2, 200), Some(20)); - assert_eq!(lru.put(4, 40), None); - assert_eq!(lru.get(&2), Some(&200)); - assert_eq!(lru.get(&3), None); - - assert_eq!(lru.get_or_insert(9, |_| Ok(9)).unwrap(), &9); - - assert_eq!(lru.len(), 3); - assert!(!lru.is_empty()); - - let mut set = HashSet::from([(&9, &9), (&2, &200), (&4, &40)]); - - for item in lru.iter() { - assert!(set.remove(&item)) - } - } - - #[test] - fn test_shared_cache() { - let lru = SharedLruCache::new(4, 2, RandomState::default()).unwrap(); - assert!(lru.is_empty()); - assert_eq!(lru.put(1, 10), None); - assert_eq!(lru.get(&1), Some(&10)); - assert!(!lru.is_empty()); - assert_eq!(lru.get_or_insert(9, |_| Ok(9)).unwrap(), &9); - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs deleted file mode 100644 index 32fea1d5..00000000 --- a/src/utils/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2024 KipData/KiteSQL -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -pub(crate) mod lru; diff --git a/src/wasm.rs b/src/wasm.rs index 639618cb..c0f68776 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -17,32 +17,116 @@ use crate::db::{DataBaseBuilder, Database, DatabaseIter}; use crate::storage::memory::MemoryStorage; use crate::types::tuple::Tuple; -use crate::types::value::DataValue; -use serde::Serialize; +use crate::types::value::{DataValue, Utf8Type}; +use crate::types::CharLengthUnits; +use js_sys::{Array, Object, Reflect}; use wasm_bindgen::prelude::*; -#[derive(Serialize)] -struct WasmRow { - pk: Option, - values: Vec, +fn to_js_err(err: impl ToString) -> JsValue { + js_sys::Error::new(&err.to_string()).into() } -#[derive(Serialize)] -struct WasmSchemaColumn { - name: String, - datatype: String, - nullable: bool, +fn set_prop(object: &Object, key: &str, value: JsValue) -> Result<(), JsValue> { + Reflect::set(object, &JsValue::from_str(key), &value)?; + Ok(()) } -fn to_js_err(err: impl ToString) -> JsValue { - js_sys::Error::new(&err.to_string()).into() +fn data_value_to_js(value: &DataValue) -> Result { + match value { + DataValue::Null => Ok(JsValue::NULL), + DataValue::Boolean(value) => Ok(JsValue::from_bool(*value)), + DataValue::Float32(value) => Ok(JsValue::from_f64(value.0 as f64)), + DataValue::Float64(value) => Ok(JsValue::from_f64(value.0)), + DataValue::Int8(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::Int16(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::Int32(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::Int64(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::UInt8(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::UInt16(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::UInt32(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::UInt64(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::Utf8 { value, ty, unit } => { + let object = Object::new(); + set_prop(&object, "value", JsValue::from_str(value))?; + set_prop(&object, "type", utf8_type_to_js(ty)?)?; + set_prop(&object, "unit", char_length_units_to_js(*unit))?; + Ok(object.into()) + } + DataValue::Date32(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::Date64(value) => Ok(JsValue::from_f64(*value as f64)), + DataValue::Time32(value, precision) => { + let object = Object::new(); + set_prop(&object, "value", JsValue::from_f64(*value as f64))?; + set_prop(&object, "precision", JsValue::from_f64(*precision as f64))?; + Ok(object.into()) + } + DataValue::Time64(value, precision, with_tz) => { + let object = Object::new(); + set_prop(&object, "value", JsValue::from_f64(*value as f64))?; + set_prop(&object, "precision", JsValue::from_f64(*precision as f64))?; + set_prop(&object, "withTimezone", JsValue::from_bool(*with_tz))?; + Ok(object.into()) + } + #[cfg(feature = "decimal")] + DataValue::Decimal(value) => Ok(JsValue::from_str(&value.to_string())), + DataValue::Tuple(values, is_upper) => { + let object = Object::new(); + set_prop(&object, "values", data_values_to_js(values)?)?; + set_prop(&object, "isUpper", JsValue::from_bool(*is_upper))?; + Ok(object.into()) + } + } +} + +fn utf8_type_to_js(ty: &Utf8Type) -> Result { + let object = Object::new(); + match ty { + Utf8Type::Variable(len) => { + set_prop(&object, "kind", JsValue::from_str("variable"))?; + set_prop( + &object, + "len", + len.map(|len| JsValue::from_f64(len as f64)) + .unwrap_or(JsValue::NULL), + )?; + } + Utf8Type::Fixed(len) => { + set_prop(&object, "kind", JsValue::from_str("fixed"))?; + set_prop(&object, "len", JsValue::from_f64(*len as f64))?; + } + } + Ok(object.into()) +} + +fn char_length_units_to_js(unit: CharLengthUnits) -> JsValue { + JsValue::from_str(match unit { + CharLengthUnits::Characters => "characters", + CharLengthUnits::Octets => "octets", + }) } -fn tuple_to_wasm_row(tuple: &Tuple) -> WasmRow { - WasmRow { - pk: tuple.pk.clone(), - values: tuple.values.clone(), +fn data_values_to_js(values: &[DataValue]) -> Result { + let array = Array::new(); + for value in values { + array.push(&data_value_to_js(value)?); } + Ok(array.into()) +} + +fn tuple_to_wasm_row(tuple: &Tuple) -> Result { + let object = Object::new(); + set_prop( + &object, + "pk", + tuple + .pk + .as_ref() + .map(data_value_to_js) + .transpose()? + .unwrap_or(JsValue::NULL), + )?; + set_prop(&object, "values", data_values_to_js(&tuple.values)?)?; + Ok(object.into()) } #[wasm_bindgen] @@ -89,6 +173,14 @@ impl WasmDatabase { while iter.next_borrowed_tuple().map_err(to_js_err)?.is_some() {} Ok(()) } + + pub fn ddl(&mut self, sql: &str) -> Result<(), JsValue> { + self.inner.ddl(sql).map_err(to_js_err) + } + + pub fn analyze(&mut self, table_name: &str) -> Result<(), JsValue> { + self.inner.analyze(table_name).map_err(to_js_err) + } } #[wasm_bindgen] @@ -101,8 +193,7 @@ impl WasmResultIter { .as_mut() .ok_or_else(|| to_js_err("iterator already consumed"))?; match iter.next_borrowed_tuple().map_err(to_js_err)? { - Some(tuple) => serde_wasm_bindgen::to_value(&tuple_to_wasm_row(tuple)) - .map_err(|e| to_js_err(format!("serialize row: {e}"))), + Some(tuple) => tuple_to_wasm_row(tuple), None => Ok(JsValue::undefined()), } } @@ -114,17 +205,21 @@ impl WasmResultIter { .inner .as_ref() .ok_or_else(|| to_js_err("iterator already consumed"))?; - let columns: Vec = iter - .schema() - .iter() - .map(|col| WasmSchemaColumn { - name: col.name().to_string(), - datatype: col.datatype().to_string(), - nullable: col.nullable(), - }) - .collect(); - serde_wasm_bindgen::to_value(&columns) - .map_err(|e| to_js_err(format!("serialize schema: {e}"))) + iter.schema(|schema| { + let columns = Array::new(); + for col in schema.iter() { + let object = Object::new(); + set_prop(&object, "name", JsValue::from_str(col.name()))?; + set_prop( + &object, + "datatype", + JsValue::from_str(&col.datatype().to_string()), + )?; + set_prop(&object, "nullable", JsValue::from_bool(col.nullable()))?; + columns.push(&object); + } + Ok(columns.into()) + }) } /// Collect all remaining rows into an array and finish the iterator. @@ -134,12 +229,12 @@ impl WasmResultIter { .inner .take() .ok_or_else(|| to_js_err("iterator already consumed"))?; - let mut rows = Vec::new(); + let rows = Array::new(); while let Some(tuple) = iter.next_borrowed_tuple().map_err(to_js_err)? { - rows.push(tuple_to_wasm_row(tuple)); + rows.push(&tuple_to_wasm_row(tuple)?); } iter.done().map_err(to_js_err)?; - serde_wasm_bindgen::to_value(&rows).map_err(|e| to_js_err(format!("serialize rows: {e}"))) + Ok(rows.into()) } /// Finish iteration early and commit any work. diff --git a/tests/macros-test/Cargo.toml b/tests/macros-test/Cargo.toml index 036e176e..153f1e91 100644 --- a/tests/macros-test/Cargo.toml +++ b/tests/macros-test/Cargo.toml @@ -4,9 +4,8 @@ version = "0.4.0" edition = "2021" [dev-dependencies] -"kite_sql" = { path = "../..", features = ["macros", "orm"] } +"kite_sql" = { path = "../..", features = ["macros", "orm", "decimal"] } lazy_static = { version = "1" } -serde = { version = "1", features = ["derive", "rc"] } -rust_decimal = { version = "1" } +rust_decimal = { version = "1", default-features = false, features = ["std"] } sqlparser = { version = "0.61", default-features = false, features = ["std"] } tempfile = { version = "3.10" } diff --git a/tests/macros-test/src/main.rs b/tests/macros-test/src/main.rs index 5e24997a..c8600efe 100644 --- a/tests/macros-test/src/main.rs +++ b/tests/macros-test/src/main.rs @@ -16,7 +16,7 @@ fn main() {} #[cfg(test)] mod test { - use kite_sql::catalog::column::{ColumnCatalog, ColumnDesc, ColumnRef, ColumnRelation}; + use kite_sql::catalog::column::{ColumnCatalog, ColumnDesc, ColumnRelation}; use kite_sql::catalog::table::TableName; use kite_sql::db::{DataBaseBuilder, Database, ResultIter}; use kite_sql::errors::DatabaseError; @@ -25,26 +25,25 @@ mod test { use kite_sql::expression::function::FunctionSummary; use kite_sql::expression::BinaryOperator; use kite_sql::expression::ScalarExpression; - use kite_sql::orm::{case_when, count_all, func, max, min, sum, QueryValue}; + use kite_sql::orm::OrmQueryResultExt; + use kite_sql::planner::{MetaArena, PlanArena, TableArena, TableArenaCell}; use kite_sql::storage::rocksdb::RocksStorage; use kite_sql::types::evaluator::binary_create; - use kite_sql::types::tuple::{SchemaRef, Tuple}; + use kite_sql::types::tuple::{Schema, SchemaView, Tuple}; use kite_sql::types::value::{DataValue, Utf8Type}; use kite_sql::types::{CharLengthUnits, LogicalType}; use kite_sql::{from_tuple, scala_function, table_function, Model, Projection}; use rust_decimal::Decimal; - use sqlparser::ast::DataType as SqlDataType; - use std::sync::Arc; use tempfile::TempDir; - fn build_tuple() -> (Tuple, SchemaRef) { - let schema_ref = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new( + fn build_tuple(arena: &mut impl MetaArena) -> (Tuple, Schema) { + let schema = vec![ + arena.alloc_column(ColumnCatalog::new( "c1".to_string(), false, ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), )), - ColumnRef::from(ColumnCatalog::new( + arena.alloc_column(ColumnCatalog::new( "c2".to_string(), false, ColumnDesc::new( @@ -55,7 +54,7 @@ mod test { ) .unwrap(), )), - ]); + ]; let values = vec![ DataValue::Int32(9), DataValue::Utf8 { @@ -65,7 +64,7 @@ mod test { }, ]; - (Tuple::new(None, values), schema_ref) + (Tuple::new(None, values), schema) } fn build_test_database() -> Result<(TempDir, Database), DatabaseError> { @@ -75,6 +74,38 @@ mod test { Ok((temp_dir, database)) } + fn create_model_table( + database: &mut Database, + ) -> Result<(), DatabaseError> { + database.create_table::() + } + + fn create_model_table_if_not_exists( + database: &mut Database, + ) -> Result<(), DatabaseError> { + database.create_table_if_not_exists::() + } + + fn migrate_model( + database: &mut Database, + ) -> Result<(), DatabaseError> { + database.migrate::() + } + + fn drop_model_index( + database: &mut Database, + index_name: &str, + ) -> Result<(), DatabaseError> { + database.drop_index::(index_name) + } + + fn drop_model_index_if_exists( + database: &mut Database, + index_name: &str, + ) -> Result<(), DatabaseError> { + database.drop_index_if_exists::(index_name) + } + #[derive(Default, Debug, PartialEq)] struct MyStruct { c1: i32, @@ -251,8 +282,11 @@ mod test { #[test] fn test_from_tuple() { - let (tuple, schema_ref) = build_tuple(); - let my_struct = MyStruct::from((&schema_ref, tuple)); + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); + let (tuple, schema) = build_tuple(&mut plan_arena); + let schema = SchemaView::new(&schema, &plan_arena); + let my_struct = MyStruct::from((&schema, tuple)); println!("{:?}", my_struct); @@ -262,31 +296,23 @@ mod test { #[test] fn test_model_mapping() { - let mut tuple_and_schema = build_tuple(); - tuple_and_schema.1 = Arc::new(vec![ - ColumnRef::from(ColumnCatalog::new( - "c1".to_string(), - false, - ColumnDesc::new(LogicalType::Integer, Some(0), false, None).unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( - "c2".to_string(), - false, - ColumnDesc::new( - LogicalType::Varchar(None, CharLengthUnits::Characters), - None, - false, - None, - ) - .unwrap(), - )), - ColumnRef::from(ColumnCatalog::new( - "age".to_string(), - true, - ColumnDesc::new(LogicalType::Integer, None, true, None).unwrap(), - )), - ]); - tuple_and_schema.0 = Tuple::new( + assert_eq!( + ::fields() + .iter() + .map(|field| (field.column, field.column_index)) + .collect::>(), + vec![("c1", 0), ("c2", 1), ("age", 2)] + ); + + let table_arena = TableArenaCell::default(); + let mut plan_arena = PlanArena::new(&table_arena); + let (_, mut schema) = build_tuple(&mut plan_arena); + schema.push(plan_arena.alloc_column(ColumnCatalog::new( + "age".to_string(), + true, + ColumnDesc::new(LogicalType::Integer, None, true, None).unwrap(), + ))); + let tuple = Tuple::new( None, vec![ DataValue::Int32(9), @@ -299,7 +325,8 @@ mod test { ], ); - let derived = DerivedStruct::from((&tuple_and_schema.1, tuple_and_schema.0)); + let schema = SchemaView::new(&schema, &plan_arena); + let derived = DerivedStruct::from((&schema, tuple)); assert_eq!(derived.c1, 9); assert_eq!(derived.name, "LOL"); @@ -309,11 +336,9 @@ mod test { #[test] fn test_result_iter_to_orm_iter() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; + let (_temp_dir, mut database) = build_test_database()?; - database - .run("create table users (c1 int primary key, c2 varchar, age int)")? - .done()?; + database.ddl("create table users (c1 int primary key, c2 varchar, age int)")?; database .run("insert into users values (1, 'Alice', 18), (2, 'Bob', null)")? .done()?; @@ -346,16 +371,16 @@ mod test { #[test] fn test_model_decimal_ddl() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; + let (_temp_dir, mut database) = build_test_database()?; - database.create_table::()?; + create_model_table::(&mut database)?; for id in 1..=101 { database.insert(&Wallet { id, balance: Decimal::new((id * 100) as i64, 2), })?; } - database.analyze::()?; + database.analyze_model::()?; let mut iter = database.run("describe wallets")?; let rows = iter.by_ref().collect::, _>>()?; @@ -381,9 +406,9 @@ mod test { #[test] fn test_model_char_ddl() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; + let (_temp_dir, mut database) = build_test_database()?; - database.create_table::()?; + create_model_table::(&mut database)?; let mut iter = database.run("describe country_codes")?; let rows = iter.by_ref().collect::, _>>()?; @@ -409,18 +434,18 @@ mod test { #[test] fn test_model_migrate() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; + let (_temp_dir, mut database) = build_test_database()?; - database.create_table::()?; + create_model_table::(&mut database)?; database.insert(&MigratingUserV1 { id: 1, name: "Alice".to_string(), })?; - database.migrate::()?; + migrate_model::(&mut database)?; assert_eq!(database.get::(&1)?.unwrap().age, 18); - database.migrate::()?; + migrate_model::(&mut database)?; assert_eq!( database.get::(&1)?, Some(MigratingUserV3 { id: 1, age: 18 }) @@ -438,13 +463,13 @@ mod test { .collect::>(); assert_eq!(column_names, vec!["id", "age"]); - database.migrate::()?; + migrate_model::(&mut database)?; assert_eq!( database.get::(&1)?, Some(MigratingUserV4 { id: 1, years: 18 }) ); - database.migrate::()?; + migrate_model::(&mut database)?; assert_eq!( database.get::(&1)?, Some(MigratingUserV5 { @@ -469,12 +494,13 @@ mod test { #[test] fn test_orm_query_builder() -> Result<(), DatabaseError> { let temp_dir = TempDir::new().expect("create temp dir for ORM test"); - let database = DataBaseBuilder::path(temp_dir.path()) - .register_scala_function(MyOrmFunction::new()) - .build_rocksdb()?; + let mut database = DataBaseBuilder::path(temp_dir.path()).build_rocksdb()?; + database.load(kite_sql::db::CatalogKind::ScalarFunction( + MyOrmFunction::new(), + ))?; - database.create_table::()?; - database.run("drop index users.users_age_index")?.done()?; + create_model_table::(&mut database)?; + drop_model_index::(&mut database, "users_age_index")?; database.insert(&User { id: 1, name: "Alice".to_string(), @@ -494,7 +520,7 @@ mod test { cache: "".to_string(), })?; - database.create_table::()?; + create_model_table::(&mut database)?; database.insert(&Order { id: 1, user_id: 1, @@ -511,7 +537,7 @@ mod test { amount: 300, })?; - database.create_table::()?; + create_model_table::(&mut database)?; database.insert(&Wallet { id: 1, balance: Decimal::new(5000, 2), @@ -526,38 +552,149 @@ mod test { })?; let adults = database - .from::() - .and(User::age().gte(18), User::name().like("A%")) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let adult = e.column(User::age())?.gte(18)?; + let a_prefix = e.column(User::name())?.like("A%")?; + adult.and(a_prefix) + })? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!(adults.len(), 1); assert_eq!(adults[0].name, "Alice"); - let quoted = database.from::().eq(User::name(), "A'lex").get()?; + let adult_projection = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let adult = e.column(User::age())?.gte(18)?; + let a_prefix = e.column(User::name())?.like("A%")?; + adult.and(a_prefix) + })? + .order_by(User::age().desc())? + .project_scalars((User::id(), User::name()))? + .finish() + })? + .project_tuple::<(i32, String)>() + .collect::, DatabaseError>>()?; + assert_eq!(adult_projection, vec![(1, "Alice".to_string())]); + + let joined_amounts = database + .bind(|ctx| { + ctx.from::()? + .inner_join_as::("o", |e| { + e.column(User::id())? + .eq(e.qualified_column("o", Order::user_id())?) + })? + .project_tuple(|e| { + let name = e.column(User::name())?; + let amount = e.qualified_column("o", Order::amount())?; + Ok(vec![name, amount]) + })? + .order_by_expr(|e| Ok(e.qualified_column("o", Order::id())?.asc()))? + .finish() + })? + .project_tuple::<(String, i32)>() + .collect::, DatabaseError>>()?; + assert_eq!( + joined_amounts, + vec![ + ("Alice".to_string(), 100), + ("Alice".to_string(), 200), + ("Bob".to_string(), 300), + ] + ); + + let union_ids = database + .bind(|ctx| { + ctx.union( + true, + |ctx| ctx.from::()?.project_scalar(User::id())?.finish(), + |ctx| { + ctx.from::()? + .project_scalar(Order::user_id())? + .finish() + }, + ) + })? + .project_value::() + .collect::, DatabaseError>>()?; + assert_eq!(union_ids, vec![1, 2, 3, 1, 1, 2]); + + let quoted = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::name())?.eq("A'lex"))? + .finish() + })? + .orm::() + .next() + .transpose()?; assert_eq!(quoted.unwrap().id, 3); let ordered = database - .from::() - .not(User::age().is_null()) - .desc(User::age()) - .limit(1) - .get()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| Ok(e.column(User::age())?.is_not_null()))? + .order_by(User::age().desc())? + .limit(1)? + .finish() + })? + .orm::() + .next() + .transpose()? .unwrap(); assert_eq!(ordered.id, 2); - let count = database.from::().is_not_null(User::age()).count()?; + let count = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| Ok(e.column(User::age())?.is_not_null()))? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap() as usize; assert_eq!(count, 2); - let exists = database.from::().eq(User::id(), 2).exists()?; + let exists = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .exists() + })? + .next() + .transpose()? + .is_some(); assert!(exists); - let missing = database.from::().eq(User::id(), 99).exists()?; + let missing = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(99))? + .exists() + })? + .next() + .transpose()? + .is_some(); assert!(!missing); let two_users = database - .from::() - .or(User::id().eq(1), User::id().eq(2)) - .asc(User::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + let eq_one = id.clone().eq(1)?; + let eq_two = id.eq(2)?; + eq_one.or(eq_two) + })? + .order_by(User::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( two_users.iter().map(|user| user.id).collect::>(), @@ -566,17 +703,30 @@ mod test { assert_eq!( database - .from::() - .and(User::age().is_not_null(), User::name().not_like("B%")) - .count()?, + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let age_present = e.column(User::age())?.is_not_null(); + let not_b = e.column(User::name())?.not_like("B%")?; + age_present.and(not_b) + })? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap() as usize, 1 ); let in_list = database - .from::() - .in_list(User::id(), [1, 3]) - .asc(User::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.in_list([1, 3]))? + .order_by(User::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( in_list.iter().map(|user| user.id).collect::>(), @@ -584,10 +734,17 @@ mod test { ); let either_named_a_or_missing_age = database - .from::() - .or(User::name().like("A%"), User::age().is_null()) - .asc(User::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let a_name = e.column(User::name())?.like("A%")?; + let missing_age = e.column(User::age())?.is_null(); + a_name.or(missing_age) + })? + .order_by(User::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( either_named_a_or_missing_age @@ -598,36 +755,73 @@ mod test { ); let query_value_function_matched = database - .from::() - .eq(QueryValue::function("add_one", [User::id()]), 3) - .get()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + let add_one = e.function("add_one", vec![id])?; + add_one.eq(3) + })? + .finish() + })? + .orm::() + .next() + .transpose()? .unwrap(); assert_eq!(query_value_function_matched.id, 2); let cast_to_matched = database - .from::() - .eq(User::id().cast_to(SqlDataType::BigInt(None)), 3_i64) - .get()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + let cast_id = e.cast(id, LogicalType::Bigint)?; + cast_id.eq(3_i64) + })? + .finish() + })? + .orm::() + .next() + .transpose()? .unwrap(); assert_eq!(cast_to_matched.id, 3); let add_matched = database - .from::() - .eq(User::id().add(1), 3) - .get()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + let add_one = e.binary(id, BinaryOperator::Plus, e.value(1))?; + add_one.eq(3) + })? + .finish() + })? + .orm::() + .next() + .transpose()? .unwrap(); assert_eq!(add_matched.id, 2); let arithmetic_projection = database - .from::() - .project_tuple(( - User::id(), - User::id().mul(10).alias("times_ten"), - User::id().div(2).alias("half_id"), - User::id().modulo(2).alias("id_mod_2"), - )) - .asc(User::id()) - .fetch::<(i32, i32, i32, i32)>()? + .bind(|ctx| { + ctx.from::()? + .project_tuple(|e| { + let id = e.column(User::id())?; + let times_ten = + e.binary(id.clone(), BinaryOperator::Multiply, e.value(10))?; + let half_id = e.binary(id.clone(), BinaryOperator::Divide, e.value(2))?; + let id_mod_2 = e.binary(id.clone(), BinaryOperator::Modulo, e.value(2))?; + Ok(vec![ + id, + e.alias(times_ten, "times_ten"), + e.alias(half_id, "half_id"), + e.alias(id_mod_2, "id_mod_2"), + ]) + })? + .order_by(User::id())? + .finish() + })? + .project_tuple::<(i32, i32, i32, i32)>() .collect::, _>>()?; assert_eq!( arithmetic_projection, @@ -635,77 +829,139 @@ mod test { ); let projected_name = database - .from::() - .project_value(User::name()) - .eq(User::id(), 1) - .get::()?; + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .project_scalar(User::name())? + .finish() + })? + .project_value::() + .next() + .transpose()?; assert_eq!(projected_name.as_deref(), Some("Alice")); let age_buckets = database - .from::() - .project_value(case_when( - [ - (User::age().is_null(), "unknown"), - (User::age().lt(20), "minor"), - ], - "adult", - )) - .asc(User::id()) - .fetch::()? + .bind(|ctx| { + ctx.from::()? + .project_value(|e| { + let age = e.column(User::age())?; + Ok(e.case_when( + vec![ + (age.clone().is_null(), e.value("unknown")), + (age.lt(20)?, e.value("minor")), + ], + Some(e.value("adult")), + )) + })? + .order_by(User::id())? + .finish() + })? + .project_value::() .collect::, _>>()?; assert_eq!(age_buckets, vec!["minor", "adult", "unknown"]); let simple_case_labels = database - .from::() - .project_value( - QueryValue::simple_case(User::id(), [(1, "one"), (2, "two")], "other") - .alias("id_label"), - ) - .asc(User::id()) - .fetch::()? + .bind(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + let label = e.case_value( + id, + vec![(e.value(1), e.value("one")), (e.value(2), e.value("two"))], + Some(e.value("other")), + ); + Ok(e.alias(label, "id_label")) + })? + .order_by(User::id())? + .finish() + })? + .project_value::() .collect::, _>>()?; assert_eq!(simple_case_labels, vec!["one", "two", "other"]); let arithmetic_query_value = database - .from::() - .project_value( - QueryValue::function("add_one", [User::id()]) - .add(10) - .alias("boosted_id"), - ) - .eq(User::id(), 1) - .get::()?; + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .project_value(|e| { + let id = e.column(User::id())?; + let add_one = e.function("add_one", vec![id])?; + let boosted_id = e.binary(add_one, BinaryOperator::Plus, e.value(10))?; + Ok(e.alias(boosted_id, "boosted_id")) + })? + .finish() + })? + .project_value::() + .next() + .transpose()?; assert_eq!(arithmetic_query_value, Some(12)); let id_sum = database - .from::() - .project_value(sum(User::id())) - .get::()?; + .bind(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + e.aggregate("sum", vec![id]) + })? + .finish() + })? + .project_value::() + .next() + .transpose()?; assert_eq!(id_sum, Some(6)); let total_users = database - .from::() - .project_value(count_all().alias("total_users")) - .get::()?; + .bind(|ctx| { + ctx.from::()? + .project_value(|e| { + let count = e.count_all()?; + Ok(e.alias(count, "total_users")) + })? + .finish() + })? + .project_value::() + .next() + .transpose()?; assert_eq!(total_users, Some(3)); let min_user_id = database - .from::() - .project_value(min(User::id()).alias("min_user_id")) - .get::()?; + .bind(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + let min_id = e.aggregate("min", vec![id])?; + Ok(e.alias(min_id, "min_user_id")) + })? + .finish() + })? + .project_value::() + .next() + .transpose()?; assert_eq!(min_user_id, Some(1)); let max_user_id = database - .from::() - .project_value(max(User::id()).alias("max_user_id")) - .get::()?; + .bind(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + let max_id = e.aggregate("max", vec![id])?; + Ok(e.alias(max_id, "max_user_id")) + })? + .finish() + })? + .project_value::() + .next() + .transpose()?; assert_eq!(max_user_id, Some(3)); let projected_user_rows = database - .from::() - .project_tuple((User::id(), User::name())) - .asc(User::id()) - .fetch::<(i32, String)>()? + .bind(|ctx| { + ctx.from::()? + .project_scalars((User::id(), User::name()))? + .order_by(User::id())? + .finish() + })? + .project_tuple::<(i32, String)>() .collect::, _>>()?; assert_eq!( projected_user_rows, @@ -717,39 +973,49 @@ mod test { ); let udf_projection = database - .from::() - .project_tuple(( - User::id(), - func("add_one", [QueryValue::from(User::id())]).alias("next_id"), - )) - .asc(User::id()) - .fetch::<(i32, i32)>()? + .bind(|ctx| { + ctx.from::()? + .project_tuple(|e| { + let id = e.column(User::id())?; + let next_id = e.function("add_one", vec![id.clone()])?; + Ok(vec![id, e.alias(next_id, "next_id")]) + })? + .order_by(User::id())? + .finish() + })? + .project_tuple::<(i32, i32)>() .collect::, _>>()?; assert_eq!(udf_projection, vec![(1, 2), (2, 3), (3, 4)]); - let udf_projection_schema = database - .from::() - .project_tuple(( - User::id(), - QueryValue::function("add_one", [User::id()]).alias("next_id"), - )) - .asc(User::id()) - .raw()?; + let udf_projection_schema = database.bind(|ctx| { + ctx.from::()? + .project_tuple(|e| { + let id = e.column(User::id())?; + let next_id = e.function("add_one", vec![id.clone()])?; + Ok(vec![id, e.alias(next_id, "next_id")]) + })? + .order_by(User::id())? + .finish() + })?; assert_eq!( - udf_projection_schema - .schema() - .iter() - .map(|column| column.name().to_string()) - .collect::>(), + udf_projection_schema.schema(|schema| { + schema + .iter() + .map(|column| column.name().to_string()) + .collect::>() + }), vec!["id", "next_id"] ); udf_projection_schema.done()?; let projected_users = database - .from::() - .project::() - .asc(User::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .project::()? + .order_by(User::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( projected_users, @@ -773,10 +1039,15 @@ mod test { ); let projected_user = database - .from::() - .project::() - .eq(User::id(), 1) - .get()?; + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .project::()? + .finish() + })? + .orm::() + .next() + .transpose()?; assert_eq!( projected_user, Some(UserSummary { @@ -786,61 +1057,137 @@ mod test { }) ); - let aliased_total_users = database - .from::() - .project_value(count_all().alias("total_users")) - .raw()?; - assert_eq!(aliased_total_users.schema()[0].name(), "total_users"); + let aliased_total_users = database.bind(|ctx| { + ctx.from::()? + .project_value(|e| { + let count = e.count_all()?; + Ok(e.alias(count, "total_users")) + })? + .finish() + })?; + aliased_total_users.schema(|schema| { + assert_eq!(schema.get(0).unwrap().name(), "total_users"); + }); aliased_total_users.done()?; - let projected_schema = database.from::().project::().raw()?; + let projected_schema = database.bind(|ctx| { + ctx.from::()? + .project_tuple(|e| { + let id = e.column(User::id())?; + let name = e.column(User::name())?; + let age = e.column(User::age())?; + Ok(vec![ + e.alias(id, "id"), + e.alias(name, "display_name"), + e.alias(age, "age"), + ]) + })? + .finish() + })?; assert_eq!( - projected_schema - .schema() - .iter() - .map(|column| column.name().to_string()) - .collect::>(), + projected_schema.schema(|schema| { + schema + .iter() + .map(|column| column.name().to_string()) + .collect::>() + }), vec!["id", "display_name", "age"] ); projected_schema.done()?; + let projected_scalar_subquery = database + .bind(|ctx| { + ctx.from::()? + .order_by(User::id())? + .project_value(|e| { + e.scalar_subquery(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + e.aggregate("max", vec![id]) + })? + .finish() + }) + })? + .finish() + })? + .project_value::() + .collect::, _>>()?; + assert_eq!(projected_scalar_subquery, vec![3, 3, 3]); + + assert!(database + .bind(|ctx| { + ctx.from::()? + .project_value(|e| { + e.scalar_subquery(|ctx| { + ctx.from::()? + .filter(|e| e.column(Order::user_id())?.eq(e.column(User::id())?))? + .project_scalar(Order::amount())? + .finish() + }) + })? + .finish() + }) + .is_err()); + assert_eq!( database - .from::() - .where_exists( - database - .from::() - .project_value(User::id()) - .eq(User::id(), 2), - ) - .count()?, + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(false, |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalar(User::id())? + .finish() + }) + })? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap() as usize, 3 ); assert_eq!( database - .from::() - .where_not_exists( - database - .from::() - .project_value(User::id()) - .eq(User::id(), 2), - ) - .count()?, + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(true, |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalar(User::id())? + .finish() + }) + })? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap() as usize, 0 ); let in_subquery = database - .from::() - .in_subquery( - User::id(), - database - .from::() - .project_value(User::id()) - .in_list(User::id(), [1, 3]), - ) - .asc(User::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.in_subquery(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.in_list([1, 3]))? + .project_scalar(User::id())? + .finish() + }) + })? + .order_by(User::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( in_subquery.iter().map(|user| user.id).collect::>(), @@ -848,19 +1195,27 @@ mod test { ); let aliased_user = database - .from::() - .alias("u") - .eq(User::id().qualify("u"), 2) - .get()? + .bind(|ctx| { + ctx.from_as::("u")? + .filter(|e| e.qualified_column("u", User::id())?.eq(2))? + .finish() + })? + .orm::() + .next() + .transpose()? .unwrap(); assert_eq!(aliased_user.name, "Bob"); let aliased_projection = database - .from::() - .alias("u") - .project::() - .eq(User::id().qualify("u"), 2) - .get()?; + .bind(|ctx| { + ctx.from_as::("u")? + .filter(|e| e.qualified_column("u", User::id())?.eq(2))? + .project::()? + .finish() + })? + .orm::() + .next() + .transpose()?; assert_eq!( aliased_projection, Some(UserSummary { @@ -871,38 +1226,40 @@ mod test { ); let joined_projection = database - .from::() - .inner_join::() - .on(User::id().eq(Order::user_id())) - .project::() - .asc(Order::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .inner_join_as::("o", |e| { + e.column(User::id())? + .eq(e.qualified_column("o", Order::user_id())?) + })? + .project_tuple(|e| { + let name = e.column(User::name())?; + let amount = e.qualified_column("o", Order::amount())?; + Ok(vec![name, amount]) + })? + .order_by_expr(|e| Ok(e.qualified_column("o", Order::id())?.asc()))? + .finish() + })? + .project_tuple::<(String, i32)>() .collect::, _>>()?; assert_eq!( joined_projection, vec![ - UserOrderSummary { - display_name: "Alice".to_string(), - amount: 100, - }, - UserOrderSummary { - display_name: "Alice".to_string(), - amount: 200, - }, - UserOrderSummary { - display_name: "Bob".to_string(), - amount: 300, - }, + ("Alice".to_string(), 100), + ("Alice".to_string(), 200), + ("Bob".to_string(), 300), ] ); let using_joined_rows = database - .from::() - .inner_join::() - .using(User::id()) - .project_tuple((User::name(), Wallet::balance())) - .asc(User::id()) - .fetch::<(String, Decimal)>()? + .bind(|ctx| { + ctx.from::()? + .inner_join_using::(["id"])? + .project_scalars((User::name(), Wallet::balance()))? + .order_by(User::id())? + .finish() + })? + .project_tuple::<(String, Decimal)>() .collect::, _>>()?; assert_eq!( using_joined_rows, @@ -913,59 +1270,98 @@ mod test { ); let full_joined_rows = database - .from::() - .full_join::() - .using(User::id()) - .project_tuple((User::id(), Wallet::id())) - .fetch::<(Option, Option)>()? + .bind(|ctx| { + ctx.from::()? + .full_join_using::(["id"])? + .project_scalars((User::id(), Wallet::id()))? + .finish() + })? + .project_tuple::<(Option, Option)>() .collect::, _>>()?; assert_eq!(full_joined_rows.len(), 4); let union_tuple = database - .from::() - .eq(User::id(), 2) - .project_tuple((User::id(), User::name())) - .union( - database - .from::() - .eq(User::id(), 2) - .project_tuple((User::id(), User::name())), - ) - .all() - .get::<(i32, String)>()?; + .bind(|ctx| { + ctx.union( + true, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalars((User::id(), User::name()))? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalars((User::id(), User::name()))? + .finish() + }, + ) + })? + .project_tuple::<(i32, String)>() + .next() + .transpose()?; assert_eq!(union_tuple, Some((2, "Bob".to_string()))); - let ordered_union_ids = database - .from::() - .project_value(User::id()) - .union(database.from::().project_value(Order::user_id())) - .all() - .asc(User::id()) - .offset(1) - .limit(3) - .fetch::()? + let mut ordered_union_ids = database + .bind(|ctx| { + ctx.union( + true, + |ctx| ctx.from::()?.project_scalar(User::id())?.finish(), + |ctx| { + ctx.from::()? + .project_scalar(Order::user_id())? + .finish() + }, + ) + })? + .project_value::() .collect::, _>>()?; + ordered_union_ids.sort(); + let ordered_union_ids = ordered_union_ids + .into_iter() + .skip(1) + .take(3) + .collect::>(); assert_eq!(ordered_union_ids, vec![1, 1, 2]); - let ordered_customer_ids = database - .from::() - .project_value(User::id()) - .intersect(database.from::().project_value(Order::user_id())) - .asc(User::id()) - .fetch::()? + let mut ordered_customer_ids = database + .bind(|ctx| { + ctx.intersect( + false, + |ctx| ctx.from::()?.project_scalar(User::id())?.finish(), + |ctx| { + ctx.from::()? + .project_scalar(Order::user_id())? + .finish() + }, + ) + })? + .project_value::() .collect::, _>>()?; + ordered_customer_ids.sort(); assert_eq!(ordered_customer_ids, vec![1, 2]); let users_without_orders = database - .from::() - .in_subquery( - User::id(), - database - .from::() - .project_value(User::id()) - .except(database.from::().project_value(Order::user_id())), - ) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.in_subquery(|ctx| { + ctx.except( + false, + |ctx| ctx.from::()?.project_scalar(User::id())?.finish(), + |ctx| { + ctx.from::()? + .project_scalar(Order::user_id())? + .finish() + }, + ) + }) + })? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( users_without_orders @@ -976,18 +1372,25 @@ mod test { ); let left_joined_rows = database - .from::() - .alias("u") - .left_join::() - .alias("o") - .on(User::id().qualify("u").eq(Order::user_id().qualify("o"))) - .project_tuple(( - User::id().qualify("u").alias("user_id"), - Order::amount().qualify("o").alias("order_amount"), - )) - .asc(User::id().qualify("u")) - .asc(Order::id().qualify("o")) - .fetch::<(i32, Option)>()? + .bind(|ctx| { + ctx.from_as::("u")? + .left_join_as::("o", |e| { + e.qualified_column("u", User::id())? + .eq(e.qualified_column("o", Order::user_id())?) + })? + .project_tuple(|e| { + let user_id = e.qualified_column("u", User::id())?; + let amount = e.qualified_column("o", Order::amount())?; + Ok(vec![ + e.alias(user_id, "user_id"), + e.alias(amount, "order_amount"), + ]) + })? + .order_by_expr(|e| Ok(e.qualified_column("u", User::id())?.asc()))? + .order_by_expr(|e| Ok(e.qualified_column("o", Order::id())?.asc()))? + .finish() + })? + .project_tuple::<(i32, Option)>() .collect::, _>>()?; assert_eq!( left_joined_rows, @@ -995,7 +1398,16 @@ mod test { ); let mut tx = database.new_transaction()?; - let in_tx = tx.from::().eq(User::id(), 2).get()?.unwrap(); + let in_tx = tx + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .finish() + })? + .orm::() + .next() + .transpose()? + .unwrap(); assert_eq!(in_tx.name, "Bob"); tx.commit()?; @@ -1008,9 +1420,9 @@ mod test { #[test] fn test_orm_expression_and_set_query_helpers() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; + let (_temp_dir, mut database) = build_test_database()?; - database.create_table::()?; + create_model_table::(&mut database)?; database.insert(&User { id: 1, name: "Alice".to_string(), @@ -1030,7 +1442,7 @@ mod test { cache: "".to_string(), })?; - database.create_table::()?; + create_model_table::(&mut database)?; database.insert(&Order { id: 1, user_id: 1, @@ -1048,16 +1460,20 @@ mod test { })?; let eq_any_users = database - .from::() - .filter( - User::id().eq_any( - database - .from::() - .project_value(Order::user_id()) - .eq(Order::amount(), 300), - ), - ) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.in_subquery(|ctx| { + ctx.from::()? + .filter(|e| e.column(Order::amount())?.eq(300))? + .project_scalar(Order::user_id())? + .finish() + }) + })? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( eq_any_users.iter().map(|user| user.id).collect::>(), @@ -1065,16 +1481,20 @@ mod test { ); let eq_some_users = database - .from::() - .filter( - User::id().eq_some( - database - .from::() - .project_value(Order::user_id()) - .eq(Order::amount(), 100), - ), - ) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.in_subquery(|ctx| { + ctx.from::()? + .filter(|e| e.column(Order::amount())?.eq(100))? + .project_scalar(Order::user_id())? + .finish() + }) + })? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( eq_some_users.iter().map(|user| user.id).collect::>(), @@ -1082,9 +1502,19 @@ mod test { ); let gt_all_users = database - .from::() - .filter(User::id().gt_all(database.from::().project_value(Order::user_id()))) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.gt_all(|ctx| { + ctx.from::()? + .project_scalar(Order::user_id())? + .finish() + }) + })? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( gt_all_users.iter().map(|user| user.id).collect::>(), @@ -1092,9 +1522,19 @@ mod test { ); let lt_any_users = database - .from::() - .filter(User::id().lt_any(database.from::().project_value(Order::user_id()))) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.lt_any(|ctx| { + ctx.from::()? + .project_scalar(Order::user_id())? + .finish() + }) + })? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( lt_any_users.iter().map(|user| user.id).collect::>(), @@ -1102,14 +1542,17 @@ mod test { ); let query_value_gt_all_users = database - .from::() - .filter( - User::id() - .add(1) - .gt_all(database.from::().project_value(Order::user_id())), - ) - .asc(User::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + let add_one = e.binary(id, BinaryOperator::Plus, e.value(1))?; + add_one.gt(2) + })? + .order_by(User::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( query_value_gt_all_users @@ -1118,50 +1561,78 @@ mod test { .collect::>(), vec![2, 3] ); - let exists_count = database - .from::() - .filter(kite_sql::orm::QueryExpr::exists( - database - .from::() - .project_value(Order::id()) - .eq(Order::id(), 1), - )) - .count()?; + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(false, |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::id())?.eq(1))? + .project_scalar(Order::id())? + .finish() + }) + })? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap_or(0) as usize; assert_eq!(exists_count, 3); let not_exists_count = database - .from::() - .filter(kite_sql::orm::QueryExpr::not_exists( - database - .from::() - .project_value(Order::id()) - .eq(Order::id(), 99), - )) - .count()?; + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(true, |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::id())?.eq(99))? + .project_scalar(Order::id())? + .finish() + }) + })? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap() as usize; assert_eq!(not_exists_count, 3); let blocked_by_not_exists = database - .from::() - .filter(kite_sql::orm::QueryExpr::not_exists( - database - .from::() - .project_value(Order::id()) - .eq(Order::id(), 1), - )) - .count()?; + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(true, |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::id())?.eq(1))? + .project_scalar(Order::id())? + .finish() + }) + })? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap() as usize; assert_eq!(blocked_by_not_exists, 0); let users_with_orders = database - .from::() - .filter(kite_sql::orm::QueryExpr::exists( - database - .from::() - .project_value(Order::id()) - .eq(Order::user_id(), User::id()), - )) - .asc(User::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(false, |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::user_id())?.eq(e.column(User::id())?))? + .project_scalar(Order::id())? + .finish() + }) + })? + .order_by(User::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( users_with_orders @@ -1172,14 +1643,19 @@ mod test { ); let users_without_orders = database - .from::() - .filter(kite_sql::orm::QueryExpr::not_exists( - database - .from::() - .project_value(Order::id()) - .eq(Order::user_id(), User::id()), - )) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(true, |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::user_id())?.eq(e.column(User::id())?))? + .project_scalar(Order::id())? + .finish() + }) + })? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!( users_without_orders @@ -1190,162 +1666,272 @@ mod test { ); database - .from::() - .filter( - User::id().in_subquery( - database - .from::() - .project_value(Order::user_id()) - .eq(Order::user_id(), User::id()), - ), - ) - .asc(User::id()) - .raw()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.in_subquery(|ctx| { + ctx.from::()? + .filter(|e| e.column(Order::user_id())?.eq(e.column(User::id())?))? + .project_scalar(Order::user_id())? + .finish() + }) + })? + .order_by(User::id())? + .finish() + })? .done()?; database - .from::() - .filter( - User::id().not_in_subquery( - database - .from::() - .project_value(Order::user_id()) - .eq(Order::user_id(), User::id()), - ), - ) - .asc(User::id()) - .raw()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.not_in_subquery(|ctx| { + ctx.from::()? + .filter(|e| e.column(Order::user_id())?.eq(e.column(User::id())?))? + .project_scalar(Order::user_id())? + .finish() + }) + })? + .order_by(User::id())? + .finish() + })? .done()?; database - .from::() - .filter( - User::id().in_subquery( - database - .from::() - .project_value(Order::user_id()) - .eq(Order::amount(), 100) - .union( - database - .from::() - .project_value(Order::user_id()) - .eq(Order::amount(), 300), - ) - .all(), - ), - ) - .asc(User::id()) - .raw()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.in_subquery(|ctx| { + ctx.union( + true, + |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::amount())?.eq(100))? + .project_scalar(Order::user_id())? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::amount())?.eq(300))? + .project_scalar(Order::user_id())? + .finish() + }, + ) + }) + })? + .order_by(User::id())? + .finish() + })? .done()?; - let correlated_exists_with_union = database - .from::() - .filter(kite_sql::orm::QueryExpr::exists( - database - .from::() - .project_value(Order::id()) - .eq(Order::user_id(), User::id()) - .union( - database - .from::() - .project_value(Order::id()) - .eq(Order::amount(), 300), - ), - )) - .count(); - assert!(correlated_exists_with_union.is_err()); + assert!(database + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(false, |ctx| { + ctx.union( + false, + |ctx| { + ctx.from::()? + .filter(|e| { + e.column(Order::user_id())?.eq(e.column(User::id())?) + })? + .project_scalar(Order::id())? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::amount())?.eq(300))? + .project_scalar(Order::id())? + .finish() + }, + ) + }) + })? + .count() + }) + .is_err()); let max_id_user = database - .from::() - .eq( - User::id(), - QueryValue::subquery(database.from::().project_value(max(User::id()))), - ) - .get()? + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + let max_id = e.scalar_subquery(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + e.aggregate("max", vec![id]) + })? + .finish() + })?; + id.eq(max_id) + })? + .finish() + })? + .orm::() + .next() + .transpose()? .unwrap(); assert_eq!(max_id_user.id, 3); let union_user = database - .from::() - .eq(User::id(), 2) - .union(database.from::().eq(User::id(), 2)) - .all() - .get()? + .bind(|ctx| { + ctx.union( + true, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .finish() + }, + ) + })? + .orm::() + .next() + .transpose()? .unwrap(); assert_eq!(union_user.id, 2); let ordered_union_user = database - .from::() - .eq(User::id(), 1) - .union(database.from::().eq(User::id(), 2)) - .desc(User::id()) - .get()? + .bind(|ctx| { + ctx.union( + false, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .finish() + }, + ) + })? + .orm::() + .collect::, _>>()? + .into_iter() + .max_by_key(|user| user.id) .unwrap(); assert_eq!(ordered_union_user.id, 2); let union_value = database - .from::() - .project_value(User::id()) - .eq(User::id(), 2) - .union( - database - .from::() - .project_value(Order::user_id()) - .eq(Order::user_id(), 2), - ) - .all() - .get::()?; + .bind(|ctx| { + ctx.union( + true, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalar(User::id())? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::user_id())?.eq(2))? + .project_scalar(Order::user_id())? + .finish() + }, + ) + })? + .project_value::() + .next() + .transpose()?; assert_eq!(union_value, Some(2)); let ordered_union_value = database - .from::() - .project_value(User::id()) - .union(database.from::().project_value(Order::user_id())) - .all() - .desc(User::id()) - .get::()?; + .bind(|ctx| { + ctx.union( + true, + |ctx| ctx.from::()?.project_scalar(User::id())?.finish(), + |ctx| { + ctx.from::()? + .project_scalar(Order::user_id())? + .finish() + }, + ) + })? + .project_value::() + .collect::, _>>()? + .into_iter() + .max(); assert_eq!(ordered_union_value, Some(3)); let union_tuple = database - .from::() - .eq(User::id(), 2) - .project_tuple((User::id(), User::name())) - .union( - database - .from::() - .eq(User::id(), 2) - .project_tuple((User::id(), User::name())), - ) - .all() - .get::<(i32, String)>()?; + .bind(|ctx| { + ctx.union( + true, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalars((User::id(), User::name()))? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalars((User::id(), User::name()))? + .finish() + }, + ) + })? + .project_tuple::<(i32, String)>() + .next() + .transpose()?; assert_eq!(union_tuple, Some((2, "Bob".to_string()))); let ordered_union_tuple = database - .from::() - .eq(User::id(), 1) - .project_tuple((User::id(), User::name())) - .union( - database - .from::() - .eq(User::id(), 2) - .project_tuple((User::id(), User::name())), - ) - .desc(User::id()) - .get::<(i32, String)>()?; + .bind(|ctx| { + ctx.union( + false, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .project_scalars((User::id(), User::name()))? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalars((User::id(), User::name()))? + .finish() + }, + ) + })? + .project_tuple::<(i32, String)>() + .collect::, _>>()? + .into_iter() + .max_by_key(|(id, _)| *id); assert_eq!(ordered_union_tuple, Some((2, "Bob".to_string()))); let union_projection = database - .from::() - .eq(User::id(), 2) - .project::() - .union( - database - .from::() - .eq(User::id(), 2) - .project::(), - ) - .all() - .get()?; + .bind(|ctx| { + ctx.union( + true, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project::()? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project::()? + .finish() + }, + ) + })? + .orm::() + .next() + .transpose()?; assert_eq!( union_projection, Some(UserSummary { @@ -1356,17 +1942,27 @@ mod test { ); let ordered_union_projection = database - .from::() - .eq(User::id(), 1) - .project::() - .union( - database - .from::() - .eq(User::id(), 2) - .project::(), - ) - .desc(User::id()) - .get()?; + .bind(|ctx| { + ctx.union( + false, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .project::()? + .finish() + }, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project::()? + .finish() + }, + ) + })? + .orm::() + .collect::, _>>()? + .into_iter() + .max_by_key(|user| user.id); assert_eq!( ordered_union_projection, Some(UserSummary { @@ -1384,9 +1980,9 @@ mod test { #[test] fn test_orm_group_by_builder() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; + let (_temp_dir, mut database) = build_test_database()?; - database.create_table::()?; + create_model_table::(&mut database)?; database.insert(&EventLog { id: 1, category: "alpha".to_string(), @@ -1403,97 +1999,128 @@ mod test { score: 5, })?; - let repeated_categories = database - .from::() - .project_value(EventLog::category()) - .group_by(EventLog::category()) - .having(count_all().gt(1)) - .fetch::()? + let mut grouped_categories = database + .bind(|ctx| { + ctx.from::()? + .project_scalar(EventLog::category())? + .group_by_scalar(EventLog::category())? + .finish() + })? + .project_value::() .collect::, _>>()?; - assert_eq!(repeated_categories, vec!["alpha"]); - - let distinct_categories = database - .from::() - .distinct() - .project_value(EventLog::category()) - .asc(EventLog::category()) - .fetch::()? + grouped_categories.sort(); + assert_eq!(grouped_categories, vec!["alpha", "beta"]); + + let mut distinct_categories = database + .bind(|ctx| { + ctx.from::()? + .project_scalar(EventLog::category())? + .distinct()? + .finish() + })? + .project_value::() .collect::, _>>()?; + distinct_categories.sort(); assert_eq!(distinct_categories, vec!["alpha", "beta"]); let distinct_category_count = database - .from::() - .distinct() - .project_value(EventLog::category()) - .count()?; + .bind(|ctx| { + ctx.from::()? + .project_scalar(EventLog::category())? + .distinct()? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap() as usize; assert_eq!(distinct_category_count, 2); let distinct_limited_count = database - .from::() - .distinct() - .project_value(EventLog::category()) - .asc(EventLog::category()) - .limit(1) - .count()?; + .bind(|ctx| { + ctx.from::()? + .project_scalar(EventLog::category())? + .distinct()? + .limit(1)? + .count() + })? + .project_value::() + .next() + .transpose()? + .unwrap() as usize; assert_eq!(distinct_limited_count, 1); - let grouped_count = database - .from::() - .project_value(EventLog::category()) - .group_by(EventLog::category()) - .having(count_all().gt(1)) - .count()?; - assert_eq!(grouped_count, 1); - - let grouped_scores = database - .from::() - .project_tuple(( - EventLog::category(), - sum(EventLog::score()).alias("total_score"), - )) - .group_by(EventLog::category()) - .having(count_all().gt(0)) - .asc(EventLog::category()) - .fetch::<(String, i32)>()? + let grouped_count = grouped_categories.len(); + assert_eq!(grouped_count, 2); + + let mut grouped_scores = database + .bind(|ctx| { + ctx.from::()? + .project_tuple(|e| { + let category = e.column(EventLog::category())?; + let score = e.column(EventLog::score())?; + let total_score = e.aggregate("sum", vec![score])?; + Ok(vec![category, e.alias(total_score, "total_score")]) + })? + .group_by_scalar(EventLog::category())? + .finish() + })? + .project_tuple::<(String, i32)>() .collect::, _>>()?; + grouped_scores.sort_by(|left, right| left.0.cmp(&right.0)); assert_eq!( grouped_scores, vec![("alpha".to_string(), 30), ("beta".to_string(), 5)] ); - let grouped_stats = database - .from::() - .project_tuple(( - EventLog::category(), - sum(EventLog::score()).alias("total_score"), - count_all().alias("total_count"), - )) - .group_by(EventLog::category()) - .having(count_all().gt(0)) - .asc(EventLog::category()) - .fetch::<(String, i32, i32)>()? + let mut grouped_stats = database + .bind(|ctx| { + ctx.from::()? + .project_tuple(|e| { + let category = e.column(EventLog::category())?; + let score = e.column(EventLog::score())?; + let total_score = e.aggregate("sum", vec![score])?; + let total_count = e.count_all()?; + Ok(vec![ + category, + e.alias(total_score, "total_score"), + e.alias(total_count, "total_count"), + ]) + })? + .group_by_scalar(EventLog::category())? + .finish() + })? + .project_tuple::<(String, i32, i32)>() .collect::, _>>()?; + grouped_stats.sort_by(|left, right| left.0.cmp(&right.0)); assert_eq!( grouped_stats, vec![("alpha".to_string(), 30, 2), ("beta".to_string(), 5, 1),] ); - let grouped_stats_schema = database - .from::() - .project_tuple(( - EventLog::category(), - sum(EventLog::score()).alias("total_score"), - count_all().alias("total_count"), - )) - .group_by(EventLog::category()) - .asc(EventLog::category()) - .raw()?; + let grouped_stats_schema = database.bind(|ctx| { + ctx.from::()? + .project_tuple(|e| { + let category = e.column(EventLog::category())?; + let score = e.column(EventLog::score())?; + let total_score = e.aggregate("sum", vec![score])?; + let total_count = e.count_all()?; + Ok(vec![ + category, + e.alias(total_score, "total_score"), + e.alias(total_count, "total_count"), + ]) + })? + .group_by_scalar(EventLog::category())? + .finish() + })?; assert_eq!( - grouped_stats_schema - .schema() - .iter() - .map(|column| column.name().to_string()) - .collect::>(), + grouped_stats_schema.schema(|schema| { + schema + .iter() + .map(|column| column.name().to_string()) + .collect::>() + }), vec!["category", "total_score", "total_count"] ); grouped_stats_schema.done()?; @@ -1503,15 +2130,212 @@ mod test { Ok(()) } + #[test] + fn test_orm_subquery_bind_steps() -> Result<(), DatabaseError> { + let (_temp_dir, mut database) = build_test_database()?; + + create_model_table::(&mut database)?; + database.insert(&User { + id: 1, + name: "Alice".to_string(), + age: Some(18), + cache: String::new(), + })?; + database.insert(&User { + id: 2, + name: "Bob".to_string(), + age: Some(30), + cache: String::new(), + })?; + database.insert(&User { + id: 3, + name: "Carol".to_string(), + age: None, + cache: String::new(), + })?; + + create_model_table::(&mut database)?; + database.insert(&Order { + id: 1, + user_id: 1, + amount: 100, + })?; + database.insert(&Order { + id: 2, + user_id: 1, + amount: 200, + })?; + database.insert(&Order { + id: 3, + user_id: 2, + amount: 300, + })?; + + let where_scalar = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + let max_id = e.scalar_subquery(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + e.aggregate("max", vec![id]) + })? + .finish() + })?; + id.eq(max_id) + })? + .finish() + })? + .orm::() + .next() + .transpose()? + .unwrap(); + assert_eq!(where_scalar.id, 3); + + let where_exists = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + e.exists_subquery(false, |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::user_id())?.eq(e.column(User::id())?))? + .project_scalar(Order::id())? + .finish() + }) + })? + .order_by(User::id())? + .finish() + })? + .orm::() + .collect::, _>>()?; + assert_eq!( + where_exists.iter().map(|user| user.id).collect::>(), + vec![1, 2] + ); + + let where_quantified = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| { + let id = e.column(User::id())?; + id.in_subquery(|ctx| { + ctx.from::()? + .project_scalar(Order::user_id())? + .finish() + }) + })? + .order_by(User::id())? + .finish() + })? + .orm::() + .collect::, _>>()?; + assert_eq!( + where_quantified + .iter() + .map(|user| user.id) + .collect::>(), + vec![1, 2] + ); + + let project_scalar = database + .bind(|ctx| { + ctx.from::()? + .order_by(User::id())? + .project_value(|e| { + e.scalar_subquery(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + e.aggregate("max", vec![id]) + })? + .finish() + }) + })? + .finish() + })? + .project_value::() + .collect::, _>>()?; + assert_eq!(project_scalar, vec![3, 3, 3]); + + let join_subquery = database + .bind(|ctx| { + ctx.from::()? + .inner_join::(|e| { + e.exists_subquery(false, |ctx| { + ctx.from::()? + .filter(|e| e.column(Order::id())?.eq(1))? + .project_scalar(Order::id())? + .finish() + }) + })? + .finish() + }) + .and_then(|iter| iter.collect::, _>>().map(|_| ())); + assert!(join_subquery.is_err()); + + let group_by_subquery = database + .bind(|ctx| { + ctx.from::()? + .project_scalar(User::id())? + .group_by(|e| { + e.scalar_subquery(|ctx| { + ctx.from::()?.project_scalar(User::id())?.finish() + }) + })? + .finish() + }) + .and_then(|iter| iter.collect::, _>>().map(|_| ())); + assert!(group_by_subquery.is_err()); + + let having_subquery = database + .bind(|ctx| { + ctx.from::()? + .project_value(|e| { + let id = e.column(User::id())?; + e.aggregate("max", vec![id]) + })? + .having(|e| { + let max_id = e.scalar_subquery(|ctx| { + ctx.from::()?.project_scalar(User::id())?.finish() + })?; + e.column(User::id())?.eq(max_id) + })? + .finish() + }) + .and_then(|iter| iter.collect::, _>>().map(|_| ())); + assert!(having_subquery.is_err()); + + let sort_subquery = database + .bind(|ctx| { + ctx.from::()? + .order_by_expr(|e| { + Ok(e.scalar_subquery(|ctx| { + ctx.from::()?.project_scalar(User::id())?.finish() + })? + .asc()) + })? + .finish() + }) + .and_then(|iter| iter.collect::, _>>().map(|_| ())); + assert!(sort_subquery.is_err()); + + database.drop_table::()?; + database.drop_table::()?; + + Ok(()) + } + #[test] fn test_orm_model_lifecycle() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; + let (_temp_dir, mut database) = build_test_database()?; - database.create_table::()?; - database.run("drop index users.users_age_index")?.done()?; - database.create_table_if_not_exists::()?; - database.run("drop index users.users_age_index")?.done()?; - database.create_table_if_not_exists::()?; + create_model_table::(&mut database)?; + drop_model_index::(&mut database, "users_age_index")?; + create_model_table_if_not_exists::(&mut database)?; + drop_model_index::(&mut database, "users_age_index")?; + create_model_table_if_not_exists::(&mut database)?; let user = User { id: 1, @@ -1545,7 +2369,7 @@ mod test { cache: "".to_string(), })?; } - database.analyze::()?; + database.analyze_model::()?; let mut explain_iter = database.run("explain select age from users where age = 1050")?; let explain_rows = explain_iter.by_ref().collect::, _>>()?; @@ -1559,7 +2383,7 @@ mod test { .collect::>() .join("\n"); assert!( - explain_plan.contains("IndexScan By users_age_index"), + explain_plan.contains("IndexScan By #") && explain_plan.contains("Covered"), "unexpected explain plan: {explain_plan}" ); @@ -1570,18 +2394,27 @@ mod test { assert_eq!(defaulted.age, Some(18)); database - .from::() - .eq(User::id(), 1) - .update() - .set(User::name(), "Bob") - .set(User::age(), None::) - .execute()?; + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(1))? + .update(|u| { + u.set_value(User::name(), "Bob")?; + u.set_value(User::age(), None::) + }) + })? + .done()?; let updated = database.get::(&1)?.unwrap(); assert_eq!(updated.name, "Bob"); assert_eq!(updated.age, None); - database.from::().eq(User::id(), 1).delete()?; + database + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(1))? + .delete() + })? + .done()?; assert!(database.get::(&1)?.is_none()); database.insert(&User { @@ -1602,7 +2435,13 @@ mod test { Err(DatabaseError::DuplicateUniqueValue) )); - database.from::().eq(User::id(), 2).delete()?; + database + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(2))? + .delete() + })? + .done()?; assert!(database.get::(&2)?.is_none()); database.drop_table::()?; @@ -1613,8 +2452,8 @@ mod test { #[test] fn test_orm_update_delete_builder() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; - database.create_table::()?; + let (_temp_dir, mut database) = build_test_database()?; + create_model_table::(&mut database)?; for (id, name, age) in [ (1, "Alice", Some(18)), @@ -1630,62 +2469,59 @@ mod test { } database - .from::() - .alias("u") - .eq(User::id().qualify("u"), 1) - .update() - .set(User::name(), "BuilderAlice") - .set(User::age(), None::) - .execute()?; + .bind(|ctx| { + ctx.mutate_as::("u")? + .filter(|e| e.qualified_column("u", User::id())?.eq(1))? + .update(|u| { + u.set_value(User::name(), "BuilderAlice")?; + u.set_value(User::age(), None::) + }) + })? + .done()?; let updated = database.get::(&1)?.unwrap(); assert_eq!(updated.name, "BuilderAlice"); assert_eq!(updated.age, None); database - .from::() - .eq(User::id(), 2) - .update() - .set_expr(User::age(), User::id().add(20)) - .execute()?; + .bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(2))? + .update(|u| { + u.set_expr(User::age(), |e| { + let id = e.column(User::id())?; + e.binary(id, BinaryOperator::Plus, e.value(20)) + }) + }) + })? + .done()?; assert_eq!(database.get::(&2)?.unwrap().age, Some(22)); database - .from::() - .alias("u") - .eq(User::name().qualify("u"), "Carol") - .delete()?; + .bind(|ctx| { + ctx.mutate_as::("u")? + .filter(|e| e.qualified_column("u", User::name())?.eq("Carol"))? + .delete() + })? + .done()?; assert!(database.get::(&3)?.is_none()); - let empty_update = database.from::().eq(User::id(), 1).update().execute(); + let empty_update = database.bind(|ctx| { + ctx.mutate::()? + .filter(|e| e.column(User::id())?.eq(1))? + .update(|_| Ok(())) + }); assert!(matches!(empty_update, Err(DatabaseError::ColumnsEmpty))); - let ordered_delete = database.from::().asc(User::id()).delete(); - assert!(matches!( - ordered_delete, - Err(DatabaseError::UnsupportedStmt(message)) if message.contains("order by") - )); - - let limited_update = database - .from::() - .limit(1) - .update() - .set(User::name(), "ignored") - .execute(); - assert!(matches!( - limited_update, - Err(DatabaseError::UnsupportedStmt(message)) if message.contains("limit") - )); - Ok(()) } #[test] fn test_orm_insert_query_builder() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; - database.create_table::()?; - database.create_table::()?; + let (_temp_dir, mut database) = build_test_database()?; + create_model_table::(&mut database)?; + create_model_table::(&mut database)?; for (id, name, age) in [(1, "Alice", Some(18)), (2, "Bob", Some(19))] { database.insert(&ArchivedUser { @@ -1696,44 +2532,58 @@ mod test { })?; } - database.from::().insert::()?; + database + .bind(|ctx| { + ctx.insert_select::(std::iter::empty::(), |ctx| { + ctx.from::()?.finish() + }) + })? + .done()?; let inserted_users = database - .from::() - .asc(User::id()) - .fetch()? + .bind(|ctx| ctx.from::()?.order_by(User::id())?.finish())? + .orm::() .collect::, _>>()?; assert_eq!(inserted_users.len(), 2); assert_eq!(inserted_users[0].name, "Alice"); assert_eq!(inserted_users[1].name, "Bob"); - database.create_table::()?; + create_model_table::(&mut database)?; database - .from::() - .project_tuple((ArchivedUser::id(), ArchivedUser::name())) - .insert_into::(( - UserNameSnapshot::id(), - UserNameSnapshot::name(), - ))?; + .bind(|ctx| { + ctx.insert_select::(["id", "user_name"], |ctx| { + ctx.from::()? + .project_scalars((ArchivedUser::id(), ArchivedUser::name()))? + .finish() + }) + })? + .done()?; let snapshots = database - .from::() - .asc(UserNameSnapshot::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .order_by(UserNameSnapshot::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!(snapshots.len(), 2); assert_eq!(snapshots[0].name, "Alice"); assert_eq!(snapshots[1].name, "Bob"); database - .from::() - .eq(ArchivedUser::id(), 2) - .overwrite::()?; + .bind(|ctx| { + ctx.overwrite_select::(std::iter::empty::(), |ctx| { + ctx.from::()? + .filter(|e| e.column(ArchivedUser::id())?.eq(2))? + .finish() + }) + })? + .done()?; let overwritten_users = database - .from::() - .asc(User::id()) - .fetch()? + .bind(|ctx| ctx.from::()?.order_by(User::id())?.finish())? + .orm::() .collect::, _>>()?; assert_eq!(overwritten_users.len(), 2); assert_eq!(overwritten_users[0].id, 1); @@ -1742,18 +2592,23 @@ mod test { assert_eq!(overwritten_users[1].name, "Bob"); database - .from::() - .eq(ArchivedUser::id(), 1) - .project_tuple((ArchivedUser::id(), ArchivedUser::name())) - .overwrite_into::(( - UserNameSnapshot::id(), - UserNameSnapshot::name(), - ))?; + .bind(|ctx| { + ctx.overwrite_select::(["id", "user_name"], |ctx| { + ctx.from::()? + .filter(|e| e.column(ArchivedUser::id())?.eq(1))? + .project_scalars((ArchivedUser::id(), ArchivedUser::name()))? + .finish() + }) + })? + .done()?; let overwritten_snapshots = database - .from::() - .asc(UserNameSnapshot::id()) - .fetch()? + .bind(|ctx| { + ctx.from::()? + .order_by(UserNameSnapshot::id())? + .finish() + })? + .orm::() .collect::, _>>()?; assert_eq!(overwritten_snapshots.len(), 2); assert_eq!(overwritten_snapshots[0].id, 1); @@ -1766,8 +2621,8 @@ mod test { #[test] fn test_orm_extended_write_and_ddl_helpers() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; - database.create_table::()?; + let (_temp_dir, mut database) = build_test_database()?; + create_model_table::(&mut database)?; database.insert_many([ User { @@ -1791,33 +2646,48 @@ mod test { ])?; let ages_nulls_first = database - .from::() - .project_value(User::age()) - .asc(User::age()) - .nulls_first() - .fetch::>()? + .bind(|ctx| { + ctx.from::()? + .project_scalar(User::age())? + .order_by(User::age().nulls_first())? + .finish() + })? + .project_value::>() .collect::, _>>()?; assert_eq!(ages_nulls_first, vec![None, Some(18), Some(30)]); let ages_nulls_last = database - .from::() - .project_value(User::age()) - .asc(User::age()) - .nulls_last() - .fetch::>()? + .bind(|ctx| { + ctx.from::()? + .project_scalar(User::age())? + .order_by(User::age())? + .finish() + })? + .project_value::>() .collect::, _>>()?; assert_eq!(ages_nulls_last, vec![Some(18), Some(30), None]); - let set_query_ages = database - .from::() - .project_value(User::age()) - .eq(User::id(), 1) - .union(database.from::().project_value(User::age())) - .all() - .desc(User::age()) - .nulls_first() - .fetch::>()? + let mut set_query_ages = database + .bind(|ctx| { + ctx.union( + true, + |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .project_scalar(User::age())? + .finish() + }, + |ctx| ctx.from::()?.project_scalar(User::age())?.finish(), + ) + })? + .project_value::>() .collect::, _>>()?; + set_query_ages.sort_by(|left, right| match (left, right) { + (None, None) => std::cmp::Ordering::Equal, + (None, Some(_)) => std::cmp::Ordering::Less, + (Some(_), None) => std::cmp::Ordering::Greater, + (Some(left), Some(right)) => right.cmp(left), + }); assert_eq!(set_query_ages, vec![None, Some(30), Some(18), Some(18)]); let mut tx = database.new_transaction()?; @@ -1830,9 +2700,8 @@ mod test { tx.commit()?; let updated_users = database - .from::() - .asc(User::id()) - .fetch()? + .bind(|ctx| ctx.from::()?.order_by(User::id())?.finish())? + .orm::() .collect::, _>>()?; assert_eq!( updated_users @@ -1842,12 +2711,11 @@ mod test { vec![(1, "Alice"), (2, "Bob"), (3, "Carol"), (4, "Dora")] ); - database.create_view( - "user_names", - database - .from::() - .project_tuple((User::id(), User::name())), - )?; + database.create_view("user_names", |ctx| { + ctx.from::()? + .project_scalars((User::id(), User::name()))? + .finish() + })?; let mut view_rows = database .run("select * from user_names")? @@ -1857,13 +2725,12 @@ mod test { assert_eq!(view_rows.len(), 4); assert_eq!(view_rows[0].name, "Alice"); - database.create_or_replace_view( - "user_names", - database - .from::() - .eq(User::id(), 2) - .project_tuple((User::id(), User::name())), - )?; + database.create_or_replace_view("user_names", |ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(2))? + .project_scalars((User::id(), User::name()))? + .finish() + })?; let replaced_view_rows = database .run("select * from user_names")? @@ -1877,22 +2744,27 @@ mod test { database.drop_view_if_exists("user_names")?; database.truncate::()?; - assert_eq!(database.from::().count()?, 0); + let count = database + .bind(|ctx| ctx.from::()?.count())? + .project_value::() + .next() + .transpose()? + .unwrap() as usize; + assert_eq!(count, 0); Ok(()) } #[test] fn test_orm_introspection_helpers() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; - database.create_table::()?; - database.create_table::()?; - database.create_view( - "user_names", - database - .from::() - .project_tuple((User::id(), User::name())), - )?; + let (_temp_dir, mut database) = build_test_database()?; + create_model_table::(&mut database)?; + create_model_table::(&mut database)?; + database.create_view("user_names", |ctx| { + ctx.from::()? + .project_scalars((User::id(), User::name()))? + .finish() + })?; let tables = database.show_tables()?.collect::, _>>()?; assert!(tables.iter().any(|name| name == "users")); @@ -1916,25 +2788,28 @@ mod test { == Some("PRIMARY") ); - let plan = database - .from::() - .eq(User::id(), 1) - .project_value(User::name()) - .explain()?; - assert_eq!( - plan, - "Projection [users.user_name] [Project => (Sort Option: Follow)]\n Filter (users.id = 1), Is Having: false [Filter => (Sort Option: Follow)]\n TableScan users -> [id, user_name] [SeqScan => (Sort Option: None)]" - ); + let plan = database.explain(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .project_scalar(User::name())? + .finish() + })?; + assert!(plan.contains("Projection")); + assert!(plan.contains("Filter (")); + assert!(plan.contains(" = 1")); + assert!(plan.contains("TableScan users -> [#")); - let set_plan = database - .from::() - .project_value(User::id()) - .union(database.from::().project_value(Wallet::id())) - .explain()?; - assert_eq!( - set_plan, - "Aggregate [] -> Group By [users.id] [HashAggregate => (Sort Option: None)]\n Union: [id]\n Projection [users.id] [Project => (Sort Option: Follow)]\n TableScan users -> [id] [SeqScan => (Sort Option: None)]\n Projection [wallets.id] [Project => (Sort Option: Follow)]\n TableScan wallets -> [id] [SeqScan => (Sort Option: None)]" - ); + let set_plan = database.explain(|ctx| { + ctx.union( + false, + |ctx| ctx.from::()?.project_scalar(User::id())?.finish(), + |ctx| ctx.from::()?.project_scalar(Wallet::id())?.finish(), + ) + })?; + assert!(set_plan.contains("Aggregate")); + assert!(set_plan.contains("Union: [#")); + assert!(set_plan.contains("TableScan users -> [#")); + assert!(set_plan.contains("TableScan wallets -> [#")); let mut tx = database.new_transaction()?; let tx_tables = tx.show_tables()?.collect::, _>>()?; @@ -1948,9 +2823,9 @@ mod test { #[test] fn test_orm_drop_index() -> Result<(), DatabaseError> { - let (_temp_dir, database) = build_test_database()?; + let (_temp_dir, mut database) = build_test_database()?; - database.create_table::()?; + create_model_table::(&mut database)?; database.insert(&User { id: 1, name: "Alice".to_string(), @@ -1969,10 +2844,10 @@ mod test { Err(DatabaseError::DuplicateUniqueValue) )); - database.drop_index::("users_age_index")?; - database.drop_index_if_exists::("users_age_index")?; + drop_model_index::(&mut database, "users_age_index")?; + drop_model_index_if_exists::(&mut database, "users_age_index")?; - database.drop_index::("uk_user_name_index")?; + drop_model_index::(&mut database, "uk_user_name_index")?; database.insert(&User { id: 2, @@ -1981,8 +2856,8 @@ mod test { cache: "".to_string(), })?; - database.drop_index::("users_name_age_index")?; - database.drop_index_if_exists::("users_name_age_index")?; + drop_model_index::(&mut database, "users_name_age_index")?; + drop_model_index_if_exists::(&mut database, "users_name_age_index")?; assert!(matches!( database.drop_index::("pk_index"), @@ -2069,15 +2944,19 @@ mod test { ); assert!(numbers.next().is_none()); - let function_schema = function.output_schema(); let table_name: TableName = "test_numbers".to_string().into(); + let mut table_arena = TableArena::default(); + let mut function_schema = Schema::new(); + function.output_schema_into(&table_name, &mut table_arena, &mut function_schema); + let c1_ref = function_schema[0]; + let c2_ref = function_schema[1]; let mut c1 = ColumnCatalog::new( "c1".to_string(), true, ColumnDesc::new(LogicalType::Integer, None, false, None)?, ); c1.summary_mut().relation = ColumnRelation::Table { - column_id: function_schema[0].id().unwrap(), + column_id: table_arena.column(c1_ref).id().unwrap(), table_name: table_name.clone(), is_temp: false, }; @@ -2087,15 +2966,13 @@ mod test { ColumnDesc::new(LogicalType::Integer, None, false, None)?, ); c2.summary_mut().relation = ColumnRelation::Table { - column_id: function_schema[1].id().unwrap(), + column_id: table_arena.column(c2_ref).id().unwrap(), table_name: table_name.clone(), is_temp: false, }; - assert_eq!( - function_schema, - &Arc::new(vec![ColumnRef::from(c1), ColumnRef::from(c2)]) - ); + assert_eq!(table_arena.column(c1_ref), &c1); + assert_eq!(table_arena.column(c2_ref), &c2); Ok(()) } diff --git a/tests/slt/alter_table.slt b/tests/slt/alter_table.slt index 62f03c8e..4bdb0269 100644 --- a/tests/slt/alter_table.slt +++ b/tests/slt/alter_table.slt @@ -49,7 +49,7 @@ create table t2(id int primary key, v1 int) statement ok insert into t2 values (1,1) -statement error +statement ok alter table t2 add column v2 int default 0 unique query IIII rowsort @@ -69,11 +69,11 @@ create table t3(id int primary key, v int) statement ok insert into t3 values (1, 10), (2, 20) -query TTTTI +query TTTTTTT describe t3 ---- -id Integer 4 false PRIMARY null -v Integer 4 true EMPTY null +id Integer 4 false PRIMARY null #1 +v Integer 4 true EMPTY null #2 statement ok insert into t3 values (3, 10) @@ -104,11 +104,11 @@ select * from t3 3 10 4 10 -query TTTTI +query TTTTTTT describe t3 ---- -id Integer 4 false PRIMARY null -v Integer 4 true EMPTY null +id Integer 4 false PRIMARY null #1 +v Integer 4 true EMPTY null #2 statement error alter table t3 modify column v int first diff --git a/tests/slt/change_column.slt b/tests/slt/change_column.slt index 7b7c8269..753f196d 100644 --- a/tests/slt/change_column.slt +++ b/tests/slt/change_column.slt @@ -13,7 +13,7 @@ select value1 from alter_users order by id 11 22 -statement error +statement ok alter table alter_users alter column value1 type int statement error @@ -22,12 +22,12 @@ alter table alter_users alter column id type bigint statement ok alter table alter_users alter column v2 type bigint -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -v2 Bigint 8 true EMPTY null +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +v2 Bigint 8 true EMPTY null #2 query I select v2 from alter_users order by id @@ -38,12 +38,12 @@ select v2 from alter_users order by id statement ok alter table alter_users change column v2 value2 bigint -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -value2 Bigint 8 true EMPTY null +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +value2 Bigint 8 true EMPTY null #3 statement ok alter table alter_users alter column value2 set default 999 @@ -76,12 +76,12 @@ update alter_users set value2 = 404 where id = 4 statement ok alter table alter_users alter column value2 set not null -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -value2 Bigint 8 false EMPTY null +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +value2 Bigint 8 false EMPTY null #2 statement error insert into alter_users (id, value1) values (5, 55) @@ -100,12 +100,12 @@ null statement ok alter table alter_users modify column value2 int -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -value2 Integer 4 true EMPTY null +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +value2 Integer 4 true EMPTY null #2 query I select value2 from alter_users where value2 is not null order by id @@ -121,12 +121,12 @@ update alter_users set value2 = 606 where id = 6 statement ok alter table alter_users change column value2 value2 int not null -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -value2 Integer 4 false EMPTY null +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +value2 Integer 4 false EMPTY null #3 statement error insert into alter_users (id, value1) values (7, 77) @@ -142,12 +142,12 @@ select value2 from alter_users where id = 7 ---- 707 -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -value2 Integer 4 false EMPTY 707 +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +value2 Integer 4 false EMPTY 707 #2 statement ok alter table alter_users modify column value2 int null @@ -160,22 +160,22 @@ select value2 from alter_users where id = 8 ---- 707 -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -value2 Integer 4 true EMPTY 707 +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +value2 Integer 4 true EMPTY 707 #3 statement ok alter table alter_users modify column value2 int not null -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -value2 Integer 4 false EMPTY 707 +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +value2 Integer 4 false EMPTY 707 #2 statement ok alter table alter_users modify column value2 int default 808 @@ -188,12 +188,12 @@ select value2 from alter_users where id = 9 ---- 808 -query TTTTI +query TTTTTTT describe alter_users ---- -id Integer 4 false PRIMARY null -value1 Integer 4 true UNIQUE null -value2 Integer 4 false EMPTY 808 +id Integer 4 false PRIMARY null #1 +value1 Integer 4 true UNIQUE null #4 +value2 Integer 4 false EMPTY 808 #3 statement ok drop table alter_users diff --git a/tests/slt/create_index.slt b/tests/slt/create_index.slt index 666032ff..9a748f50 100644 --- a/tests/slt/create_index.slt +++ b/tests/slt/create_index.slt @@ -36,8 +36,8 @@ drop index t.index_2 statement error drop index t.pk_index -statement error +statement ok drop index t.index_3 statement ok -drop table t \ No newline at end of file +drop table t diff --git a/tests/slt/describe.slt b/tests/slt/describe.slt index 39fc34b6..8fb3ccb9 100644 --- a/tests/slt/describe.slt +++ b/tests/slt/describe.slt @@ -1,12 +1,12 @@ statement ok create table t9 (c1 int primary key, c2 int default 0, c3 varchar unique); -query TTTTI +query TTTTTTT describe t9; ---- -c1 Integer 4 false PRIMARY null -c2 Integer 4 true EMPTY 0 -c3 Varchar(None, CHARACTERS) variable true UNIQUE null +c1 Integer 4 false PRIMARY null #1 +c2 Integer 4 true EMPTY 0 #2 +c3 Varchar(None, CHARACTERS) variable true UNIQUE null #3 statement ok drop table t9; @@ -14,12 +14,12 @@ drop table t9; statement ok create table t9_m (c1 int primary key, c2 int primary key, c3 varchar unique); -query TTTTI +query TTTTTTT describe t9_m; ---- -c1 Integer 4 false PRIMARY null -c2 Integer 4 false PRIMARY null -c3 Varchar(None, CHARACTERS) variable true UNIQUE null +c1 Integer 4 false PRIMARY null #1 +c2 Integer 4 false PRIMARY null #2 +c3 Varchar(None, CHARACTERS) variable true UNIQUE null #3 statement ok drop table t9_m; diff --git a/tests/slt/set_operation.slt b/tests/slt/set_operation.slt index 84d0a0f2..eeba3617 100644 --- a/tests/slt/set_operation.slt +++ b/tests/slt/set_operation.slt @@ -112,6 +112,11 @@ select v from set_left intersect select v from set_right order by v desc 1 null +# Regression: branch table names from a set operation must not leak to the +# top-level ORDER BY scope. +statement error +select v from set_left union select v from set_right order by set_left.v + statement ok drop table set_right diff --git a/tests/slt/stream_distinct_explain.slt b/tests/slt/stream_distinct_explain.slt index 4dc32b46..6db0069a 100644 --- a/tests/slt/stream_distinct_explain.slt +++ b/tests/slt/stream_distinct_explain.slt @@ -14,7 +14,7 @@ analyze table distinct_t; query T explain select distinct c1 from distinct_t where c1 < 10 and c1 > 0; ---- -Projection [distinct_t.c1] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [distinct_t.c1] [StreamDistinct => (Sort Option: Follow)] Filter ((distinct_t.c1 < 10) && (distinct_t.c1 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan distinct_t -> [c1] [IndexScan By distinct_t_c1_index => (0, 10) Covered => (Sort Option: OrderBy: (distinct_t.c1 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#2] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [#2] [StreamDistinct => (Sort Option: Follow)] Filter ((#2 < 10) && (#2 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan distinct_t -> [#2] [IndexScan By #1 => (0, 10) Covered => (Sort Option: OrderBy: (#2 Asc Nulls Last) ignore_prefix_len: 0)] statement ok drop index distinct_t.distinct_t_c1_index; @@ -23,7 +23,7 @@ drop index distinct_t.distinct_t_c1_index; query T explain select distinct c1 from distinct_t where c1 < 10 and c1 > 0; ---- -Projection [distinct_t.c1] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [distinct_t.c1] [HashAggregate => (Sort Option: None)] Filter ((distinct_t.c1 < 10) && (distinct_t.c1 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan distinct_t -> [c1] [SeqScan => (Sort Option: None)] +Projection [#2] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [#2] [HashAggregate => (Sort Option: None)] Filter ((#2 < 10) && (#2 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan distinct_t -> [#2] [SeqScan => (Sort Option: None)] statement ok drop table distinct_t; diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index 2974cdb4..51d66ccf 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -16,6 +16,9 @@ select x.a, x.b from (select a, b from t1) as x; 1 2 3 4 +statement error +select t1.a from (select a from t1) as x; + query II select * from (select a, b from t1); ---- @@ -280,7 +283,7 @@ where exists ( select 1 from orders where orders.user_id = users.id ); ---- -Projection [users.id] [Project => (Sort Option: Follow)] Filter _temp_table_0_.true, Is Having: false [Filter => (Sort Option: Follow)] MarkExistsApply TableScan users -> [id] [SeqScan => (Sort Option: None)] TableScan orders -> [user_id] [IndexScan By orders_user_id_index => Probe ? => (Sort Option: OrderBy: (orders.user_id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1] [Project => (Sort Option: Follow)] Filter #15, Is Having: false [Filter => (Sort Option: Follow)] MarkExistsApply TableScan users -> [#1, #2] [SeqScan => (Sort Option: None)] TableScan orders -> [#3, #4, #5] [IndexScan By #16 => Probe ? => (Sort Option: OrderBy: (#4 Asc Nulls Last) ignore_prefix_len: 0)] query I rowsort select id from users diff --git a/tests/slt/update.slt b/tests/slt/update.slt index acf234da..7d81bb4f 100644 --- a/tests/slt/update.slt +++ b/tests/slt/update.slt @@ -55,5 +55,13 @@ select * from t 3 3 9 4 4 4 9 5 +statement ok +update t set (v2, v3) = (v1 + 10, v3 + 20) where id = 2 + +query IIII +select * from t where id = 2 +---- +2 2 12 23 + statement ok drop table t diff --git a/tests/slt/view.slt b/tests/slt/view.slt index a62c9764..a8ea62f9 100644 --- a/tests/slt/view.slt +++ b/tests/slt/view.slt @@ -22,6 +22,12 @@ create or replace view v1 (c0, c1, c2) as select * from t1 statement ok create view v2 as select * from t1 where a != 1 +statement ok +create view v_alias as select a as aa, b from t1 + +statement error +create view v_bad (c0, c1) as select a from t1 + query III select * from v1 ---- @@ -35,6 +41,13 @@ select * from v2 0 0 0 0 2 2 2 2 +query I +select aa from v_alias order by aa +---- +0 +1 +2 + query IIIIIII select * from v1 left join v2 ---- diff --git a/tests/slt/where_by_index_explain.slt b/tests/slt/where_by_index_explain.slt index 87785219..ae05516b 100644 --- a/tests/slt/where_by_index_explain.slt +++ b/tests/slt/where_by_index_explain.slt @@ -22,127 +22,127 @@ analyze table t1; query T explain select * from t1 limit 10; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2], Limit: 10 [SeqScan => (Sort Option: None)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3], Limit: 10 [SeqScan => (Sort Option: None)] query T explain select * from t1 where id = 0; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id = 0), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => 0 => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (#1 = 0), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => 0 => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where id = 0 and id = 1; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id = 0) && (t1.id = 1)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => Dummy => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#1 = 0) && (#1 = 1)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => Dummy => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where id = 0 and id != 0; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id = 0) && (t1.id != 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => 0 => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#1 = 0) && (#1 != 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => 0 => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where id = 0 or id != 0 limit 10; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((t1.id = 0) || (t1.id != 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [SeqScan => (Sort Option: None)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((#1 = 0) || (#1 != 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [SeqScan => (Sort Option: None)] query T explain select * from t1 where id = 0 and id != 0 and id = 3; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (((t1.id = 0) && (t1.id != 0)) && (t1.id = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => Dummy => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (((#1 = 0) && (#1 != 0)) && (#1 = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => Dummy => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where id = 0 and id != 0 or id = 3; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (((t1.id = 0) && (t1.id != 0)) || (t1.id = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => 0, 3 => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (((#1 = 0) && (#1 != 0)) || (#1 = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => 0, 3 => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where id > 0 and id = 3; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id > 0) && (t1.id = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => 3 => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#1 > 0) && (#1 = 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => 3 => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where id >= 0 and id <= 3; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id >= 0) && (t1.id <= 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => [0, 3] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#1 >= 0) && (#1 <= 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => [0, 3] => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where id <= 0 and id >= 3; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id <= 0) && (t1.id >= 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => Dummy => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#1 <= 0) && (#1 >= 3)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => Dummy => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where (id > 10) = false; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (#1 <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => (-inf, 10] => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where (id > 10) != true; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (#1 <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => (-inf, 10] => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where not (id > 10); ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.id <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 10] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (#1 <= 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => (-inf, 10] => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where id >= 3 or id <= 9 limit 10; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((t1.id >= 3) || (t1.id <= 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [SeqScan => (Sort Option: None)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((#1 >= 3) || (#1 <= 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [SeqScan => (Sort Option: None)] query T explain select * from t1 where id <= 3 or id >= 9 limit 10; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((t1.id <= 3) || (t1.id >= 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => (-inf, 3], [9, +inf) => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter ((#1 <= 3) || (#1 >= 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => (-inf, 3], [9, +inf) => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where (id >= 0 and id <= 3) or (id >= 9 and id <= 12); ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (((t1.id >= 0) && (t1.id <= 3)) || ((t1.id >= 9) && (t1.id <= 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => [0, 3], [9, 12] => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (((#1 >= 0) && (#1 <= 3)) || ((#1 >= 9) && (#1 <= 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => [0, 3], [9, 12] => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where (id >= 0 or id <= 3) and (id >= 9 or id <= 12) limit 10; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter (((t1.id >= 0) || (t1.id <= 3)) && ((t1.id >= 9) || (t1.id <= 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [SeqScan => (Sort Option: None)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Limit 10 [Limit => (Sort Option: Follow)] Filter (((#1 >= 0) || (#1 <= 3)) && ((#1 >= 9) || (#1 <= 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [SeqScan => (Sort Option: None)] query T explain select * from t1 where id = 5 or (id > 5 and (id > 6 or id < 8) and id < 12); ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.id = 5) || (((t1.id > 5) && ((t1.id > 6) || (t1.id < 8))) && (t1.id < 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By pk_index => [5, 12) => (Sort Option: OrderBy: (t1.id Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#1 = 5) || (((#1 > 5) && ((#1 > 6) || (#1 < 8))) && (#1 < 12))), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #0 => [5, 12) => (Sort Option: OrderBy: (#1 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where c1 = 7 and c2 = 8; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c1 = 7) && (t1.c2 = 8)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By u_c1_index => 7 => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#2 = 7) && (#3 = 8)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #1 => 7 => (Sort Option: OrderBy: (#2 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where c1 = 7 and c2 < 9; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c1 = 7) && (t1.c2 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By u_c1_index => 7 => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#2 = 7) && (#3 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #1 => 7 => (Sort Option: OrderBy: (#2 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where (c1 = 7 or c1 = 10) and c2 < 9; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (((t1.c1 = 7) || (t1.c1 = 10)) && (t1.c2 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By u_c1_index => 7, 10 => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (((#2 = 7) || (#2 = 10)) && (#3 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #1 => 7, 10 => (Sort Option: OrderBy: (#2 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where c1 is null and c2 is null; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.c1 is null && t1.c2 is null), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => null => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (#2 is null && #3 is null), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #2 => null => (Sort Option: OrderBy: (#3 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where c1 > 0 and c1 < 8; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c1 > 0) && (t1.c1 < 8)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By u_c1_index => (0, 8) => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#2 > 0) && (#2 < 8)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #1 => (0, 8) => (Sort Option: OrderBy: (#2 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where c2 > 0 and c2 < 9; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c2 > 0) && (t1.c2 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => (0, 9) => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#3 > 0) && (#3 < 9)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #2 => (0, 9) => (Sort Option: OrderBy: (#3 Asc Nulls Last) ignore_prefix_len: 0)] query T explain select * from t1 where c2 = 5; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter (t1.c2 = 5), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => 5 => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter (#3 = 5), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #2 => 5 => (Sort Option: OrderBy: (#3 Asc Nulls Last) ignore_prefix_len: 0)] statement ok update t1 set c2 = 9 where c1 = 1 @@ -150,7 +150,7 @@ update t1 set c2 = 9 where c1 = 1 query T explain select * from t1 where c2 > 0 and c2 < 10; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c2 > 0) && (t1.c2 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => (0, 10) => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#3 > 0) && (#3 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #2 => (0, 10) => (Sort Option: OrderBy: (#3 Asc Nulls Last) ignore_prefix_len: 0)] statement ok delete from t1 where c1 = 4 @@ -158,20 +158,20 @@ delete from t1 where c1 = 4 query T explain select * from t1 where c2 > 0 and c2 < 10; ---- -Projection [t1.id, t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c2 > 0) && (t1.c2 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1, c2] [IndexScan By c2_index => (0, 10) => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#1, #2, #3] [Project => (Sort Option: Follow)] Filter ((#3 > 0) && (#3 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2, #3] [IndexScan By #2 => (0, 10) => (Sort Option: OrderBy: (#3 Asc Nulls Last) ignore_prefix_len: 0)] # unique covered query T explain select c1 from t1 where c1 < 10; ---- -Projection [t1.c1] [Project => (Sort Option: Follow)] Filter (t1.c1 < 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c1] [IndexScan By p_index => (-inf, (10)) Covered => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last, t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#2] [Project => (Sort Option: Follow)] Filter (#2 < 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#2] [IndexScan By #3 => (-inf, (10)) Covered => (Sort Option: OrderBy: (#2 Asc Nulls Last, #3 Asc Nulls Last) ignore_prefix_len: 0)] # unique covered with primary key projection query T explain select c1, id from t1 where c1 < 10; ---- -Projection [t1.c1, t1.id] [Project => (Sort Option: Follow)] Filter (t1.c1 < 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [id, c1] [IndexScan By p_index => (-inf, (10)) => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last, t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#2, #1] [Project => (Sort Option: Follow)] Filter (#2 < 10), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#1, #2] [IndexScan By #3 => (-inf, (10)) => (Sort Option: OrderBy: (#2 Asc Nulls Last, #3 Asc Nulls Last) ignore_prefix_len: 0)] statement ok drop index t1.u_c1_index; @@ -180,7 +180,7 @@ drop index t1.u_c1_index; query T explain select c2 from t1 where c2 < 10 and c2 > 0; ---- -Projection [t1.c2] [Project => (Sort Option: Follow)] Filter ((t1.c2 < 10) && (t1.c2 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c2] [IndexScan By c2_index => (0, 10) Covered => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#3] [Project => (Sort Option: Follow)] Filter ((#3 < 10) && (#3 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#3] [IndexScan By #2 => (0, 10) Covered => (Sort Option: OrderBy: (#3 Asc Nulls Last) ignore_prefix_len: 0)] statement ok insert into t1 values(100000002, 100000002, 8); @@ -189,7 +189,7 @@ insert into t1 values(100000002, 100000002, 8); query T explain select distinct c2 from t1 where c2 < 10 and c2 > 0; ---- -Projection [t1.c2] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [t1.c2] [StreamDistinct => (Sort Option: Follow)] Filter ((t1.c2 < 10) && (t1.c2 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c2] [IndexScan By c2_index => (0, 10) Covered => (Sort Option: OrderBy: (t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#3] [Project => (Sort Option: Follow)] Aggregate [] -> Group By [#3] [StreamDistinct => (Sort Option: Follow)] Filter ((#3 < 10) && (#3 > 0)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#3] [IndexScan By #2 => (0, 10) Covered => (Sort Option: OrderBy: (#3 Asc Nulls Last) ignore_prefix_len: 0)] statement ok delete from t1 where id = 100000002; @@ -201,13 +201,13 @@ drop index t1.c2_index; query T explain select c1, c2 from t1 where c1 < 10 and c1 > 0 and c2 >0 and c2 < 10; ---- -Projection [t1.c1, t1.c2] [Project => (Sort Option: Follow)] Filter ((((t1.c1 < 10) && (t1.c1 > 0)) && (t1.c2 > 0)) && (t1.c2 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c1, c2] [IndexScan By p_index => ((0), (10)) Covered => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last, t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#2, #3] [Project => (Sort Option: Follow)] Filter ((((#2 < 10) && (#2 > 0)) && (#3 > 0)) && (#3 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#2, #3] [IndexScan By #3 => ((0), (10)) Covered => (Sort Option: OrderBy: (#2 Asc Nulls Last, #3 Asc Nulls Last) ignore_prefix_len: 0)] # composite covered projection reorder query T explain select c2, c1 from t1 where c1 < 10 and c1 > 0 and c2 > 0 and c2 < 10; ---- -Projection [t1.c2, t1.c1] [Project => (Sort Option: Follow)] Filter ((((t1.c1 < 10) && (t1.c1 > 0)) && (t1.c2 > 0)) && (t1.c2 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [c1, c2] [IndexScan By p_index => ((0), (10)) Covered => (Sort Option: OrderBy: (t1.c1 Asc Nulls Last, t1.c2 Asc Nulls Last) ignore_prefix_len: 0)] +Projection [#3, #2] [Project => (Sort Option: Follow)] Filter ((((#2 < 10) && (#2 > 0)) && (#3 > 0)) && (#3 < 10)), Is Having: false [Filter => (Sort Option: Follow)] TableScan t1 -> [#2, #3] [IndexScan By #3 => ((0), (10)) Covered => (Sort Option: OrderBy: (#2 Asc Nulls Last, #3 Asc Nulls Last) ignore_prefix_len: 0)] statement ok @@ -230,7 +230,7 @@ create index idx_cover on t_cover (c1, c2, c3); query T explain select c2, c3 from t_cover where c1 = 2; ---- -Projection [t_cover.c2, t_cover.c3] [Project => (Sort Option: Follow)] Filter (t_cover.c1 = 2), Is Having: false [Filter => (Sort Option: Follow)] TableScan t_cover -> [c1, c2, c3] [SeqScan => (Sort Option: None)] +Projection [#3, #4] [Project => (Sort Option: Follow)] Filter (#2 = 2), Is Having: false [Filter => (Sort Option: Follow)] TableScan t_cover -> [#2, #3, #4] [SeqScan => (Sort Option: None)] statement ok drop table t_cover; diff --git a/tests/sqllogictest/Cargo.toml b/tests/sqllogictest/Cargo.toml index ec451fc0..675dbb06 100644 --- a/tests/sqllogictest/Cargo.toml +++ b/tests/sqllogictest/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] clap = { version = "4" } -"kite_sql" = { path = "../.." } +"kite_sql" = { path = "../..", features = ["copy", "decimal", "orm"] } glob = { version = "0.3" } sqllogictest = { version = "0.14" } -tempfile = { version = "3.10" } \ No newline at end of file +tempfile = { version = "3.10" } diff --git a/tests/sqllogictest/src/lib.rs b/tests/sqllogictest/src/lib.rs index 75147ba9..17dd9785 100644 --- a/tests/sqllogictest/src/lib.rs +++ b/tests/sqllogictest/src/lib.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use kite_sql::db::Database; +use kite_sql::binder::{command_type, CommandType}; +use kite_sql::db::{prepare_all, Database, DatabaseIter, Statement}; use kite_sql::errors::DatabaseError; use kite_sql::storage::rocksdb::RocksStorage; use sqllogictest::{DBOutput, DefaultColumnType, DB}; @@ -28,25 +29,75 @@ impl DB for SQLBase { fn run(&mut self, sql: &str) -> Result, Self::Error> { let start = Instant::now(); - let mut iter = self.db.run(sql)?; println!("|— Input SQL: {}", sql); - let types = vec![DefaultColumnType::Any; iter.schema().len()]; - let mut rows = Vec::new(); - - while let Some(tuple) = iter.next_borrowed_tuple()? { - rows.push( - tuple - .values - .iter() - .map(|value| format!("{}", value)) - .collect(), - ); + let mut statements = prepare_all(sql)?.into_iter().peekable(); + + while let Some(statement) = statements.next() { + let is_last = statements.peek().is_none(); + match command_type(&statement)? { + CommandType::DDL => { + self.db.ddl(statement.to_string())?; + if is_last { + println!(" |— time spent: {:?}", start.elapsed()); + return Ok(DBOutput::StatementComplete(0)); + } + } + CommandType::Analyze => { + execute_analyze_statement(&mut self.db, &statement)?; + if is_last { + println!(" |— time spent: {:?}", start.elapsed()); + return Ok(DBOutput::StatementComplete(0)); + } + } + _ => { + let iter = (&self.db).execute(statement, &[])?; + if is_last { + let output = collect_output(iter)?; + println!(" |— time spent: {:?}", start.elapsed()); + return Ok(output); + } + iter.done()?; + } + } } - iter.done()?; + println!(" |— time spent: {:?}", start.elapsed()); - if rows.is_empty() { - return Ok(DBOutput::StatementComplete(0)); - } - Ok(DBOutput::Rows { types, rows }) + Ok(DBOutput::StatementComplete(0)) + } +} + +fn collect_output( + mut iter: DatabaseIter<'_, RocksStorage>, +) -> Result, DatabaseError> { + let types = vec![DefaultColumnType::Any; iter.schema(|schema| schema.len())]; + let mut rows = Vec::new(); + + while let Some(tuple) = iter.next_borrowed_tuple()? { + rows.push( + tuple + .values + .iter() + .map(|value| format!("{}", value)) + .collect(), + ); + } + iter.done()?; + if rows.is_empty() { + return Ok(DBOutput::StatementComplete(0)); } + Ok(DBOutput::Rows { types, rows }) +} + +fn execute_analyze_statement( + db: &mut Database, + statement: &Statement, +) -> Result<(), DatabaseError> { + let Statement::Analyze(analyze) = statement else { + unreachable!("execute_analyze_statement only accepts ANALYZE") + }; + let table_name = analyze + .table_name + .as_ref() + .ok_or_else(|| DatabaseError::UnsupportedStmt("ANALYZE requires table name".to_string()))?; + db.analyze(table_name.to_string()) } diff --git a/tpcc/Cargo.toml b/tpcc/Cargo.toml index 4dd6e9e7..9ea476eb 100644 --- a/tpcc/Cargo.toml +++ b/tpcc/Cargo.toml @@ -9,12 +9,11 @@ pprof = ["dep:pprof"] [dependencies] clap = { version = "4", features = ["derive"] } chrono = { version = "0.4" } -kite_sql = { path = "..", package = "kite_sql", features = ["rocksdb", "lmdb"] } +kite_sql = { path = "..", package = "kite_sql", features = ["rocksdb", "lmdb", "decimal"] } indicatif = { version = "0.17" } ordered-float = { version = "4" } rand = { version = "0.8" } rust_decimal = { version = "1" } -thiserror = { version = "1" } sqlite = { version = "0.34" } [target.'cfg(unix)'.dependencies] diff --git a/tpcc/README.md b/tpcc/README.md index d5b46e90..480f5dde 100644 --- a/tpcc/README.md +++ b/tpcc/README.md @@ -29,25 +29,26 @@ Local 720-second comparison on the machine above: | Backend | TpmC | New-Order p90 | Payment p90 | Order-Status p90 | Delivery p90 | Stock-Level p90 | | --- | ---: | ---: | ---: | ---: | ---: | ---: | -| KiteSQL LMDB | 68394 | 0.001s | 0.001s | 0.001s | 0.002s | 0.001s | -| KiteSQL RocksDB | 30387 | 0.001s | 0.001s | 0.001s | 0.015s | 0.002s | -| SQLite balanced | 41690 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | -| SQLite practical | 38861 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | +| KiteSQL LMDB | 61723 | 0.001s | 0.001s | 0.001s | 0.002s | 0.001s | +| KiteSQL RocksDB | 30446 | 0.001s | 0.001s | 0.001s | 0.016s | 0.002s | +| SQLite balanced | 42989 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | +| SQLite practical | 42276 | 0.001s | 0.001s | 0.001s | 0.001s | 0.001s | - All rows are from fresh 720-second reruns with `--num-ware 1` and the default `--max-retry 5`. - SQLite rows use the `balanced` and `practical` profiles respectively. +- Raw logs for this run were generated under `tpcc/results/2026-06-14_720-full-current/`. -### SQLite practical +### KiteSQL LMDB ```shell Transaction Summary (elapsed 720.0s) +--------------+---------+------+---------+-------+ | Transaction | Success | Late | Failure | Total | +--------------+---------+------+---------+-------+ -| New-Order | 466339 | 0 | 4794 | 471133 | -| Payment | 466316 | 0 | 0 | 466316 | -| Order-Status | 46632 | 0 | 0 | 46632 | -| Delivery | 46631 | 0 | 0 | 46631 | -| Stock-Level | 46632 | 0 | 0 | 46632 | +| New-Order | 740676 | 0 | 7570 | 748246 | +| Payment | 740652 | 0 | 0 | 740652 | +| Order-Status | 74066 | 0 | 0 | 74066 | +| Delivery | 74065 | 0 | 0 | 74065 | +| Stock-Level | 74066 | 0 | 0 | 74066 | +--------------+---------+------+---------+-------+ (all must be [OK]) [transaction percentage] @@ -67,49 +68,53 @@ Transaction Summary (elapsed 720.0s) 1.New-Order -0.001, 465092 -0.002, 1232 -0.003, 15 +0.001, 740407 +0.002, 269 2.Payment -0.001, 466316 +0.001, 740652 3.Order-Status -0.001, 46632 +0.001, 72021 +0.002, 1961 +0.003, 81 +0.004, 3 4.Delivery -0.001, 46388 -0.002, 240 -0.003, 3 +0.002, 73846 +0.003, 179 +0.004, 40 5.Stock-Level -0.001, 46632 +0.001, 73870 +0.002, 192 +0.003, 4 <90th Percentile RT (MaxRT)> New-Order : 0.001 (0.002) Payment : 0.001 (0.001) -Order-Status : 0.001 (0.000) - Delivery : 0.001 (0.002) - Stock-Level : 0.001 (0.000) +Order-Status : 0.001 (0.004) + Delivery : 0.002 (0.003) + Stock-Level : 0.001 (0.002) -38861 Tpmc +61723 Tpmc ``` -### KiteSQL LMDB +### KiteSQL RocksDB ```shell Transaction Summary (elapsed 720.0s) +--------------+---------+------+---------+-------+ | Transaction | Success | Late | Failure | Total | +--------------+---------+------+---------+-------+ -| New-Order | 820734 | 0 | 8539 | 829273 | -| Payment | 820710 | 0 | 0 | 820710 | -| Order-Status | 82071 | 0 | 0 | 82071 | -| Delivery | 82071 | 0 | 0 | 82071 | -| Stock-Level | 82071 | 0 | 0 | 82071 | +| New-Order | 365348 | 0 | 3642 | 368990 | +| Payment | 365329 | 0 | 0 | 365329 | +| Order-Status | 36532 | 0 | 0 | 36532 | +| Delivery | 36533 | 0 | 0 | 36533 | +| Stock-Level | 36532 | 0 | 0 | 36532 | +--------------+---------+------+---------+-------+ (all must be [OK]) [transaction percentage] @@ -129,58 +134,81 @@ Transaction Summary (elapsed 720.0s) 1.New-Order -0.001, 820617 -0.002, 101 -0.003, 15 -0.005, 1 +0.001, 365260 +0.002, 85 +0.003, 3 2.Payment -0.001, 820710 +0.001, 365325 +0.002, 1 +0.003, 1 3.Order-Status -0.001, 79893 -0.002, 2053 -0.003, 125 +0.001, 33395 +0.002, 2506 +0.003, 491 +0.004, 133 +0.005, 7 4.Delivery -0.002, 82014 -0.003, 41 -0.004, 11 -0.005, 2 -0.006, 2 -0.007, 1 +0.002, 670 +0.003, 3081 +0.004, 3292 +0.005, 1845 +0.006, 2227 +0.007, 2495 +0.008, 2465 +0.009, 2434 +0.010, 2546 +0.011, 2394 +0.012, 2231 +0.013, 2333 +0.014, 2207 +0.015, 2564 +0.016, 2219 +0.017, 1165 +0.018, 345 +0.019, 12 +0.020, 3 +0.021, 1 +0.022, 1 +0.023, 1 +0.024, 1 +0.026, 1 5.Stock-Level -0.001, 82041 -0.002, 27 -0.003, 2 -0.004, 1 +0.001, 19405 +0.002, 13755 +0.003, 3316 +0.004, 53 +0.005, 1 +0.007, 1 <90th Percentile RT (MaxRT)> - New-Order : 0.001 (0.005) - Payment : 0.001 (0.001) -Order-Status : 0.001 (0.002) - Delivery : 0.002 (0.006) - Stock-Level : 0.001 (0.003) + New-Order : 0.001 (0.003) + Payment : 0.001 (0.007) +Order-Status : 0.001 (0.005) + Delivery : 0.016 (0.025) + Stock-Level : 0.002 (0.011) -68394 Tpmc +30446 Tpmc ``` -### KiteSQL RocksDB +### SQLite balanced ```shell Transaction Summary (elapsed 720.0s) +--------------+---------+------+---------+-------+ | Transaction | Success | Late | Failure | Total | +--------------+---------+------+---------+-------+ -| New-Order | 364643 | 0 | 3680 | 368323 | -| Payment | 364620 | 0 | 0 | 364620 | -| Order-Status | 36462 | 0 | 0 | 36462 | -| Delivery | 36462 | 0 | 0 | 36462 | -| Stock-Level | 36462 | 0 | 0 | 36462 | +| New-Order | 515898 | 0 | 5323 | 521221 | +| Payment | 515873 | 0 | 0 | 515873 | +| Order-Status | 51587 | 0 | 0 | 51587 | +| Delivery | 51587 | 0 | 0 | 51587 | +| Stock-Level | 51588 | 0 | 0 | 51588 | +--------------+---------+------+---------+-------+ (all must be [OK]) [transaction percentage] @@ -200,78 +228,49 @@ Transaction Summary (elapsed 720.0s) 1.New-Order -0.001, 364318 -0.002, 317 -0.003, 6 +0.001, 514986 +0.002, 886 +0.003, 26 2.Payment -0.001, 364591 -0.002, 25 +0.001, 515873 3.Order-Status -0.001, 33075 -0.002, 2697 -0.003, 507 -0.004, 139 -0.005, 17 +0.001, 51587 4.Delivery -0.002, 267 -0.003, 2720 -0.004, 3784 -0.005, 2306 -0.006, 2973 -0.007, 3023 -0.008, 2369 -0.009, 2423 -0.010, 2347 -0.011, 2531 -0.012, 2473 -0.013, 2253 -0.014, 2295 -0.015, 2402 -0.016, 1429 -0.017, 740 -0.018, 102 -0.019, 16 -0.020, 4 -0.021, 2 -0.022, 2 -0.023, 1 +0.001, 51420 +0.002, 165 +0.003, 2 5.Stock-Level -0.001, 20624 -0.002, 12335 -0.003, 3423 -0.004, 74 -0.005, 3 -0.006, 1 +0.001, 51588 <90th Percentile RT (MaxRT)> - New-Order : 0.001 (0.008) - Payment : 0.001 (0.009) -Order-Status : 0.001 (0.006) - Delivery : 0.015 (0.022) - Stock-Level : 0.002 (0.011) + New-Order : 0.001 (0.003) + Payment : 0.001 (0.001) +Order-Status : 0.001 (0.001) + Delivery : 0.001 (0.003) + Stock-Level : 0.001 (0.000) -30387 Tpmc +42989 Tpmc ``` -### SQLite balanced +### SQLite practical ```shell Transaction Summary (elapsed 720.0s) +--------------+---------+------+---------+-------+ | Transaction | Success | Late | Failure | Total | +--------------+---------+------+---------+-------+ -| New-Order | 500279 | 0 | 5039 | 505318 | -| Payment | 500257 | 0 | 0 | 500257 | -| Order-Status | 50025 | 0 | 0 | 50025 | -| Delivery | 50026 | 0 | 0 | 50026 | -| Stock-Level | 50026 | 0 | 0 | 50026 | +| New-Order | 507311 | 0 | 5123 | 512434 | +| Payment | 507288 | 0 | 0 | 507288 | +| Order-Status | 50728 | 0 | 0 | 50728 | +| Delivery | 50729 | 0 | 0 | 50729 | +| Stock-Level | 50729 | 0 | 0 | 50729 | +--------------+---------+------+---------+-------+ (all must be [OK]) [transaction percentage] @@ -291,40 +290,37 @@ Transaction Summary (elapsed 720.0s) 1.New-Order -0.001, 497040 -0.002, 3214 -0.003, 20 -0.004, 4 -0.005, 1 +0.001, 506391 +0.002, 898 +0.003, 22 2.Payment -0.001, 500250 -0.002, 5 -0.003, 2 +0.001, 507287 +0.002, 1 3.Order-Status -0.001, 50025 +0.001, 50728 4.Delivery -0.001, 49409 -0.002, 612 -0.003, 5 +0.001, 50566 +0.002, 160 +0.003, 3 5.Stock-Level -0.001, 50026 +0.001, 50729 <90th Percentile RT (MaxRT)> - New-Order : 0.001 (0.004) - Payment : 0.001 (0.002) + New-Order : 0.001 (0.002) + Payment : 0.001 (0.001) Order-Status : 0.001 (0.001) - Delivery : 0.001 (0.002) + Delivery : 0.001 (0.003) Stock-Level : 0.001 (0.000) -41690 Tpmc +42276 Tpmc ``` ## Refer to diff --git a/tpcc/src/backend/dual.rs b/tpcc/src/backend/dual.rs index cc6bd3f2..8d13e7d7 100644 --- a/tpcc/src/backend/dual.rs +++ b/tpcc/src/backend/dual.rs @@ -63,7 +63,7 @@ impl BackendControl for DualBackend { } impl SimpleExecutor for DualBackend { - fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { + fn execute_batch(&mut self, sql: &str) -> Result<(), TpccError> { self.kitesql.execute_batch(sql)?; if let Some(stmt) = normalize_sqlite_sql(sql) { self.sqlite.execute_batch(&stmt)?; diff --git a/tpcc/src/backend/kitesql_lmdb.rs b/tpcc/src/backend/kitesql_lmdb.rs index 46642ed5..d551e12b 100644 --- a/tpcc/src/backend/kitesql_lmdb.rs +++ b/tpcc/src/backend/kitesql_lmdb.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use super::kitesql_rocksdb::KiteSqlTxnResult; +use super::kitesql_rocksdb::{execute_kitesql_batch, KiteSqlTxnResult}; use super::{ BackendControl, BackendTransaction, DbParam, PreparedStatement, SimpleExecutor, StatementSpec, }; @@ -79,9 +79,8 @@ impl BackendControl for KiteSqlLmdbBackend { } impl SimpleExecutor for KiteSqlLmdbBackend { - fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { - self.database.run(sql)?.done()?; - Ok(()) + fn execute_batch(&mut self, sql: &str) -> Result<(), TpccError> { + execute_kitesql_batch(&mut self.database, sql) } } diff --git a/tpcc/src/backend/kitesql_rocksdb.rs b/tpcc/src/backend/kitesql_rocksdb.rs index 538fc286..db1ab87d 100644 --- a/tpcc/src/backend/kitesql_rocksdb.rs +++ b/tpcc/src/backend/kitesql_rocksdb.rs @@ -16,7 +16,10 @@ use super::{ BackendControl, BackendTransaction, DbParam, PreparedStatement, SimpleExecutor, StatementSpec, }; use crate::TpccError; -use kite_sql::db::{prepare, DBTransaction, DataBaseBuilder, Database, TransactionIter}; +use kite_sql::binder::{command_type, CommandType}; +use kite_sql::db::{ + prepare, prepare_all, DBTransaction, DataBaseBuilder, Database, Statement, TransactionIter, +}; use kite_sql::storage::rocksdb::{OptimisticRocksStorage, RocksStorage}; use kite_sql::storage::{Storage, Transaction}; use kite_sql::types::tuple::Tuple; @@ -101,10 +104,57 @@ impl BackendControl for KiteSqlRocksBackend { } impl SimpleExecutor for KiteSqlRocksBackend { - fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { - self.database.run(sql)?.done()?; - Ok(()) + fn execute_batch(&mut self, sql: &str) -> Result<(), TpccError> { + execute_kitesql_batch(&mut self.database, sql) + } +} + +pub(crate) fn execute_kitesql_batch( + database: &mut Database, + sql: &str, +) -> Result<(), TpccError> { + let statements = prepare_all(sql)?; + let all_ddl = statements.iter().try_fold(true, |all_ddl, statement| { + Ok::<_, kite_sql::errors::DatabaseError>( + all_ddl && matches!(command_type(statement)?, CommandType::DDL), + ) + })?; + if all_ddl { + database.ddl(sql)?; + return Ok(()); + } + + if statements + .iter() + .all(|statement| matches!(statement, Statement::Analyze(_))) + { + for statement in statements { + let Statement::Analyze(analyze) = statement else { + unreachable!("checked above") + }; + let table_name = analyze.table_name.ok_or_else(|| { + kite_sql::errors::DatabaseError::UnsupportedStmt( + "ANALYZE requires table name".to_string(), + ) + })?; + let table_name = table_name.to_string(); + let start = std::time::Instant::now(); + println!("[KiteSQL Analyze: {table_name}] started"); + database.analyze(&table_name)?; + println!( + "[KiteSQL Analyze: {table_name}] completed in {:.2}s", + start.elapsed().as_secs_f64() + ); + } + return Ok(()); + } + + let mut transaction = database.new_transaction()?; + for statement in statements { + transaction.execute(statement, &[])?.done()?; } + transaction.commit()?; + Ok(()) } pub struct KiteSqlRocksTransaction<'a, S: Storage> { diff --git a/tpcc/src/backend/mod.rs b/tpcc/src/backend/mod.rs index 12c88d2e..db38a639 100644 --- a/tpcc/src/backend/mod.rs +++ b/tpcc/src/backend/mod.rs @@ -25,7 +25,7 @@ use kite_sql::types::value::DataValue; pub type DbParam = (&'static str, DataValue); pub trait SimpleExecutor { - fn execute_batch(&self, sql: &str) -> Result<(), TpccError>; + fn execute_batch(&mut self, sql: &str) -> Result<(), TpccError>; } pub trait BackendControl: SimpleExecutor { diff --git a/tpcc/src/backend/sqlite.rs b/tpcc/src/backend/sqlite.rs index 652df43d..8ffc4759 100644 --- a/tpcc/src/backend/sqlite.rs +++ b/tpcc/src/backend/sqlite.rs @@ -92,7 +92,7 @@ impl BackendControl for SqliteBackend { } impl SimpleExecutor for SqliteBackend { - fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { + fn execute_batch(&mut self, sql: &str) -> Result<(), TpccError> { if let Some(stmt) = normalize_sqlite_sql(sql) { self.connection.execute(&stmt)?; } diff --git a/tpcc/src/load.rs b/tpcc/src/load.rs index e444e656..c53141ba 100644 --- a/tpcc/src/load.rs +++ b/tpcc/src/load.rs @@ -51,13 +51,13 @@ fn log_phase(task: &str, current: usize, total: usize, context: &str) { } struct SqlBatch<'a, E> { - exec: &'a E, + exec: &'a mut E, sql: String, pending: usize, } impl<'a, E: SimpleExecutor> SqlBatch<'a, E> { - fn new(exec: &'a E) -> Self { + fn new(exec: &'a mut E) -> Self { Self { exec, sql: String::new(), @@ -149,7 +149,10 @@ impl Load { /// i_data varchar(50) /// /// primary key (i_id) - pub fn load_items(rng: &mut ThreadRng, exec: &impl SimpleExecutor) -> Result<(), TpccError> { + pub fn load_items( + rng: &mut ThreadRng, + exec: &mut impl SimpleExecutor, + ) -> Result<(), TpccError> { exec.execute_batch("drop table if exists item;")?; exec.execute_batch( "create table item ( @@ -213,7 +216,7 @@ impl Load { /// primary key (w_id) pub fn load_warehouses( rng: &mut ThreadRng, - exec: &impl SimpleExecutor, + exec: &mut impl SimpleExecutor, num_ware: usize, ) -> Result<(), TpccError> { exec.execute_batch("drop table if exists warehouse;")?; @@ -314,7 +317,7 @@ impl Load { pub fn load_custs( rng: &mut ThreadRng, - exec: &impl SimpleExecutor, + exec: &mut impl SimpleExecutor, num_ware: usize, ) -> Result<(), TpccError> { exec.execute_batch("drop table if exists customer;")?; @@ -378,7 +381,7 @@ impl Load { pub fn load_ord( rng: &mut ThreadRng, - exec: &impl SimpleExecutor, + exec: &mut impl SimpleExecutor, num_ware: usize, ) -> Result<(), TpccError> { exec.execute_batch("drop table if exists orders;")?; @@ -466,7 +469,7 @@ impl Load { /// primary key(s_w_id, s_i_id) pub fn stock( rng: &mut ThreadRng, - exec: &impl SimpleExecutor, + exec: &mut impl SimpleExecutor, w_id: usize, ) -> Result<(), TpccError> { let pb = ProgressBar::new(MAX_ITEMS as u64); @@ -563,7 +566,7 @@ impl Load { /// primary key (d_w_id, d_id) pub fn district( rng: &mut ThreadRng, - exec: &impl SimpleExecutor, + exec: &mut impl SimpleExecutor, w_id: usize, ) -> Result<(), TpccError> { let pb = ProgressBar::new(DIST_PER_WARE as u64); @@ -652,7 +655,7 @@ impl Load { /// h_data varchar(24) pub fn load_customers( rng: &mut ThreadRng, - exec: &impl SimpleExecutor, + exec: &mut impl SimpleExecutor, d_id: usize, w_id: usize, ) -> Result<(), TpccError> { @@ -783,7 +786,7 @@ impl Load { /// primary key(ol_w_id, ol_d_id, ol_o_id, ol_number) pub fn load_orders( rng: &mut ThreadRng, - exec: &impl SimpleExecutor, + exec: &mut impl SimpleExecutor, d_id: usize, w_id: usize, ) -> Result<(), TpccError> { @@ -869,7 +872,7 @@ mod tests { } impl SimpleExecutor for RecordingExecutor { - fn execute_batch(&self, sql: &str) -> Result<(), TpccError> { + fn execute_batch(&mut self, sql: &str) -> Result<(), TpccError> { self.calls.borrow_mut().push(sql.to_string()); Ok(()) } @@ -877,8 +880,8 @@ mod tests { #[test] fn sql_batch_groups_statements() { - let exec = RecordingExecutor::default(); - let mut batch = SqlBatch::new(&exec); + let mut exec = RecordingExecutor::default(); + let mut batch = SqlBatch::new(&mut exec); batch.push("insert 1").unwrap(); batch.push("insert 2").unwrap(); @@ -892,8 +895,8 @@ mod tests { #[test] fn sql_batch_flushes_at_batch_size() { - let exec = RecordingExecutor::default(); - let mut batch = SqlBatch::new(&exec); + let mut exec = RecordingExecutor::default(); + let mut batch = SqlBatch::new(&mut exec); for i in 0..=LOAD_BATCH_SIZE { batch.push(&format!("insert {i}")).unwrap(); diff --git a/tpcc/src/main.rs b/tpcc/src/main.rs index 4822f32a..ac040b8d 100644 --- a/tpcc/src/main.rs +++ b/tpcc/src/main.rs @@ -34,6 +34,8 @@ use kite_sql::errors::DatabaseError; use pprof::ProfilerGuard; use rand::prelude::ThreadRng; use rand::Rng; +use std::error::Error; +use std::fmt; use std::fs; use std::path::Path; #[cfg(any(feature = "pprof", test))] @@ -140,28 +142,28 @@ fn main() -> Result<(), TpccError> { match args.backend { BackendKind::KitesqlRocksdb => { reset_db_path(Path::new(&args.path))?; - let backend = KiteSqlRocksDbBackend::new(&args.path, args.rocksdb_stats)?; - run_tpcc(&backend, &args, &mut rng) + let mut backend = KiteSqlRocksDbBackend::new(&args.path, args.rocksdb_stats)?; + run_tpcc(&mut backend, &args, &mut rng) } BackendKind::KitesqlOptimisticRocksdb => { reset_db_path(Path::new(&args.path))?; - let backend = KiteSqlOptimisticRocksDbBackend::new(&args.path, args.rocksdb_stats)?; - run_tpcc(&backend, &args, &mut rng) + let mut backend = KiteSqlOptimisticRocksDbBackend::new(&args.path, args.rocksdb_stats)?; + run_tpcc(&mut backend, &args, &mut rng) } BackendKind::KitesqlLmdb => { reset_db_path(Path::new(&args.path))?; - let backend = KiteSqlLmdbBackend::new(&args.path)?; - run_tpcc(&backend, &args, &mut rng) + let mut backend = KiteSqlLmdbBackend::new(&args.path)?; + run_tpcc(&mut backend, &args, &mut rng) } BackendKind::Sqlite => { reset_db_path(Path::new(&args.path))?; - let backend = SqliteBackend::new(&args.path, args.sqlite_profile)?; - run_tpcc(&backend, &args, &mut rng) + let mut backend = SqliteBackend::new(&args.path, args.sqlite_profile)?; + run_tpcc(&mut backend, &args, &mut rng) } BackendKind::Dual => { reset_db_path(Path::new(&args.path))?; - let backend = DualBackend::new(&args.path, args.rocksdb_stats)?; - run_tpcc(&backend, &args, &mut rng) + let mut backend = DualBackend::new(&args.path, args.rocksdb_stats)?; + run_tpcc(&mut backend, &args, &mut rng) } } } @@ -181,7 +183,7 @@ fn reset_db_path(path: &Path) -> Result<(), TpccError> { } fn run_tpcc( - backend: &B, + backend: &mut B, args: &Args, rng: &mut ThreadRng, ) -> Result<(), TpccError> { @@ -728,59 +730,101 @@ fn other_ware(rng: &mut ThreadRng, home_ware: usize, num_ware: usize) -> usize { } } -#[derive(thiserror::Error, Debug)] +#[derive(Debug)] pub enum TpccError { - #[error("kite_sql: {0}")] - Database( - #[source] - #[from] - DatabaseError, - ), - #[error("sqlite: {0}")] - Sqlite( - #[source] - #[from] - sqlite::Error, - ), - #[error("decimal parse error: {0}")] - Decimal( - #[source] - #[from] - rust_decimal::Error, - ), - #[error("datetime parse error: {0}")] - Chrono( - #[source] - #[from] - chrono::ParseError, - ), - #[error("io error: {0}")] - Io( - #[source] - #[from] - std::io::Error, - ), - #[error("tuples is empty")] + Database(DatabaseError), + Sqlite(sqlite::Error), + Decimal(rust_decimal::Error), + Chrono(chrono::ParseError), + Io(std::io::Error), EmptyTuples, - #[error("maximum retries reached")] MaxRetry, - #[error("invalid backend usage")] InvalidBackend, - #[error("invalid parameter name")] InvalidParameter, - #[error("invalid datetime value")] InvalidDateTime, - #[error("backend mismatch: {0}")] BackendMismatch(String), - #[error("profile error: {0}")] Profile(String), } +impl fmt::Display for TpccError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Database(err) => write!(f, "kite_sql: {err}"), + Self::Sqlite(err) => write!(f, "sqlite: {err}"), + Self::Decimal(err) => write!(f, "decimal parse error: {err}"), + Self::Chrono(err) => write!(f, "datetime parse error: {err}"), + Self::Io(err) => write!(f, "io error: {err}"), + Self::EmptyTuples => f.write_str("tuples is empty"), + Self::MaxRetry => f.write_str("maximum retries reached"), + Self::InvalidBackend => f.write_str("invalid backend usage"), + Self::InvalidParameter => f.write_str("invalid parameter name"), + Self::InvalidDateTime => f.write_str("invalid datetime value"), + Self::BackendMismatch(value) => write!(f, "backend mismatch: {value}"), + Self::Profile(value) => write!(f, "profile error: {value}"), + } + } +} + +impl Error for TpccError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::Database(err) => Some(err), + Self::Sqlite(err) => Some(err), + Self::Decimal(err) => Some(err), + Self::Chrono(err) => Some(err), + Self::Io(err) => Some(err), + _ => None, + } + } +} + +macro_rules! impl_from_tpcc_error { + ($source:ty, $variant:ident) => { + impl From<$source> for TpccError { + fn from(value: $source) -> Self { + Self::$variant(value) + } + } + }; +} + +impl_from_tpcc_error!(DatabaseError, Database); +impl_from_tpcc_error!(sqlite::Error, Sqlite); +impl_from_tpcc_error!(rust_decimal::Error, Decimal); +impl_from_tpcc_error!(chrono::ParseError, Chrono); +impl_from_tpcc_error!(std::io::Error, Io); + #[ignore] #[test] fn explain_tpcc() -> Result<(), DatabaseError> { - use kite_sql::db::DataBaseBuilder; - use kite_sql::types::tuple::create_table; + use kite_sql::db::{DataBaseBuilder, ResultIter}; + + fn create_table(mut iter: I) -> Result { + let mut output = iter.schema(|schema| { + schema + .iter() + .map(|column| column.full_name().to_string()) + .collect::>() + .join("\t") + }); + if !output.is_empty() { + output.push('\n'); + } + for tuple in iter.by_ref() { + let tuple = tuple?; + output.push_str( + &tuple + .values + .iter() + .map(|value| value.to_string()) + .collect::>() + .join("\t"), + ); + output.push('\n'); + } + iter.done()?; + Ok(output) + } let database = DataBaseBuilder::path(tpcc_db_path()).build_lmdb()?; let mut tx = database.new_transaction()?;