diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py index 0ed50918253c..308dc13a737d 100644 --- a/paimon-python/pypaimon/manifest/manifest_file_manager.py +++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py @@ -17,7 +17,7 @@ ################################################################################ from concurrent.futures import ThreadPoolExecutor from io import BytesIO -from typing import List +from typing import Callable, List, Optional import fastavro @@ -48,10 +48,13 @@ def __init__(self, table): self.trimmed_primary_keys_fields = self.table.trimmed_primary_keys_fields def read_entries_parallel(self, manifest_files: List[ManifestFileMeta], manifest_entry_filter=None, - drop_stats=True, max_workers=8) -> List[ManifestEntry]: + drop_stats=True, max_workers=8, + early_entry_filter: Optional[Callable[[int, int], bool]] = None + ) -> List[ManifestEntry]: def _process_single_manifest(manifest_file: ManifestFileMeta) -> List[ManifestEntry]: - return self.read(manifest_file.file_name, manifest_entry_filter, drop_stats) + return self.read(manifest_file.file_name, manifest_entry_filter, drop_stats, + early_entry_filter=early_entry_filter) def _entry_identifier(e: ManifestEntry) -> tuple: return ( @@ -81,7 +84,19 @@ def _entry_identifier(e: ManifestEntry) -> tuple: ] return final_entries - def read(self, manifest_file_name: str, manifest_entry_filter=None, drop_stats=True) -> List[ManifestEntry]: + def read(self, manifest_file_name: str, manifest_entry_filter=None, drop_stats=True, + early_entry_filter: Optional[Callable[[int, int], bool]] = None + ) -> List[ManifestEntry]: + """ + early_entry_filter: optional ``(bucket, total_buckets) -> bool`` + called immediately after the avro record is parsed. 
Mirrors + Java ``BucketFilter`` applied at the InternalRow stage in + ``ManifestEntryCache``: when it returns False, the entry's + ``_FILE`` block / partition / stats are never deserialized. + Caller is responsible for soundness (any non-pruning rule must + return True). The full ``manifest_entry_filter`` still runs on + the survivors. + """ manifest_file_path = f"{self.manifest_path}/{manifest_file_name}" entries = [] @@ -91,6 +106,9 @@ def read(self, manifest_file_name: str, manifest_entry_filter=None, drop_stats=T reader = fastavro.reader(buffer) for record in reader: + if early_entry_filter is not None and not early_entry_filter( + record['_BUCKET'], record['_TOTAL_BUCKETS']): + continue file_dict = dict(record['_FILE']) key_dict = dict(file_dict['_KEY_STATS']) key_stats = SimpleStats( diff --git a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py new file mode 100644 index 000000000000..e0c3b6bfa409 --- /dev/null +++ b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py @@ -0,0 +1,297 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################ + +""" +Predicate-driven bucket pruning for HASH_FIXED tables. + +Mirrors Java's ``org.apache.paimon.operation.BucketSelectConverter``: +walk the predicate, isolate AND clauses that constrain bucket-key fields +with Equal/In, take the cartesian product of literal values, hash each +combination using the writer's hash routine, and produce the set of +buckets the query can possibly hit. All other entries are safely dropped. + +Hard correctness contract: the bucket set this returns is a *superset* of +the buckets that contain any matching rows. False positives (over-keeping) +are allowed; a false negative (dropping a bucket that has matching rows) +MUST never happen — that would be silent data loss. + +The hashing routine reuses the module-level ``_hash_bytes_by_words`` / +``_bucket_from_hash`` from ``pypaimon.write.row_key_extractor`` — the same +code path the writer uses to assign rows to buckets. Reusing it (rather +than copying) is what guarantees read/write hash agreement in the face of +future routine changes. + +Conservative scope (deliberately narrower than Java's general flexibility): + + * Only HASH_FIXED tables (caller's responsibility to gate; this module + does not look at the bucket mode itself). + * All bucket-key fields must be constrained, with Equal or In, in a + single AND-of-OR-of-literals shape. If any bucket-key column is + unconstrained, return None — the caller must scan all buckets. + * Repeated constraints on the same bucket-key column under top-level + AND (e.g. ``id IN (1,2,3) AND id IN (2,3,4)``) intersect their + literal sets (mirrors Java ``BucketSelector.retainAll``). An empty + intersection means the predicate is unsatisfiable; we return None + (no pruning — the caller scans all buckets), which stays sound. + * Total cartesian product capped at MAX_VALUES (1000), again matching + Java; above that, fall back to a full scan. + +Returns a callable ``selector(bucket: int, total_buckets: int) -> bool``.
+The callable is cached per ``total_buckets`` to handle the rare case +where bucket count varies across snapshots (rescale). + +TODO: per-partition predicate pre-evaluation. + + Predicates of the form ``(part='a' AND bk IN (1,2)) OR (part='b' AND bk + IN (3,4))`` currently fall through to "no pruning" because the top-level + OR mixes partition and bucket-key constraints. Java simplifies the + predicate per concrete partition value first (replacing partition + leaves with literal true/false and folding AND/OR), so each partition + gets a tighter bucket-key predicate and the corresponding bucket set. + + Implementing this here would need three pieces: + + * a Predicate-replace walker that substitutes a partition's actual + values into partition-column leaves (mirrors Java's + ``paimon-common/.../predicate/PartitionValuePredicateVisitor.java``). + * lifting ``_Selector`` to key its cache by + ``(partition, total_buckets)`` instead of just ``total_buckets``. + * threading the partition value into the early manifest filter + ``FileScanner._build_early_bucket_filter`` (currently sees only + ``(bucket, total_buckets)``). +""" + +from itertools import product +from typing import Any, Callable, Dict, FrozenSet, List, Optional + +from pypaimon.common.predicate import Predicate +from pypaimon.schema.data_types import DataField +from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer +from pypaimon.table.row.internal_row import RowKind +from pypaimon.write.row_key_extractor import (_bucket_from_hash, + _hash_bytes_by_words) + +MAX_VALUES = 1000 + +# Bucket-key column types where the Python serializer is not byte-aligned +# with the writer's logical value, or with Java's ``BinaryRow`` byte layout. +# A divergent hash is silent data loss (false-negative), so the selector +# refuses to build at all when a bucket-key field has one of these types. +# +# Two reasons something gets blacklisted: +# +# 1. 
Locale / precision drift between writer and reader for equal logical +# values (DECIMAL via float-vs-Decimal, TIMESTAMP via naive datetime +# timezone interpretation). +# 2. Composite / nested types whose ``GenericRowSerializer`` byte layout +# hasn't been cross-validated against Java's ``BinaryRow`` (ARRAY, +# MAP, ROW, MULTISET, VARIANT, BLOB). Until that validation lands, +# treating them as safe risks a hash divergence. +_UNSAFE_BUCKET_KEY_TYPES = ( + 'DECIMAL', + 'TIMESTAMP', + 'ARRAY', + 'MAP', + 'ROW', + 'MULTISET', + 'VARIANT', + 'BLOB', +) + + +def _has_unsafe_bucket_key_type(bucket_key_fields: List[DataField]) -> bool: + for f in bucket_key_fields: + type_name = getattr(getattr(f, 'type', None), 'type', '') + if not type_name: + continue + head = type_name.split('(')[0].strip().upper() + if any(head.startswith(prefix) for prefix in _UNSAFE_BUCKET_KEY_TYPES): + return True + return False + + +def _split_and(p: Predicate) -> List[Predicate]: + if p.method == 'and': + out: List[Predicate] = [] + for child in (p.literals or []): + out.extend(_split_and(child)) + return out + return [p] + + +def _split_or(p: Predicate) -> List[Predicate]: + if p.method == 'or': + out: List[Predicate] = [] + for child in (p.literals or []): + out.extend(_split_or(child)) + return out + return [p] + + +def _extract_or_clause(or_pred: Predicate, + bk_name_to_slot: Dict[str, int]) -> Optional[List[Any]]: + """For one AND-child predicate, return either: + * ``[slot_index, [literal, ...]]`` — the OR/leaf is a pure + Equal-or-In list on a single bucket-key field; or + * ``None`` — the clause is not a bucket-key constraint we can + safely use; the caller skips it. + + All disjuncts must hit the same bucket-key column. Mixed columns or + non-Equal/In operators disqualify the entire AND clause. 
+ """ + slot: Optional[int] = None + values: List[Any] = [] + for clause in _split_or(or_pred): + if clause.method not in ('equal', 'in'): + return None + if clause.field is None or clause.field not in bk_name_to_slot: + return None + this_slot = bk_name_to_slot[clause.field] + if slot is not None and slot != this_slot: + return None + slot = this_slot + for lit in (clause.literals or []): + # Java filters nulls; null literals are degenerate (NULL = NULL + # is UNKNOWN in SQL). Producing zero values for a slot will + # cascade through the cartesian product to "match nothing", + # which is the same observable behaviour as Java. + if lit is None: + continue + values.append(lit) + return None if slot is None else [slot, values] + + +class _Selector: + """Callable bucket filter, lazy + cached per ``total_buckets``.""" + + __slots__ = ('_combinations', '_bucket_key_fields', '_cache') + + def __init__(self, combinations: List[List[Any]], + bucket_key_fields: List[DataField]): + self._combinations = combinations + self._bucket_key_fields = bucket_key_fields + self._cache: Dict[int, FrozenSet[int]] = {} + + def __call__(self, bucket: int, total_buckets: int) -> bool: + # ``total_buckets <= 0`` shows up for postpone / legacy / special + # entries and must NOT be pruned: returning False here would drop + # rows the writer hashed under a different convention. Fail open. + if total_buckets <= 0: + return True + try: + return bucket in self._compute(total_buckets) + except Exception: + # Fail open on any hashing/serialization error (e.g. a literal + # type that doesn't match the bucket-key column's atomic type: + # ``pb.equal('id_bigint', 'foo')`` — GenericRowSerializer raises + # struct.error trying to pack the string as int64). Crashing + # the entire scan here would be worse than skipping pruning; + # the soundness contract still forbids false-negatives. 
+ return True + + def _compute(self, total_buckets: int) -> FrozenSet[int]: + cached = self._cache.get(total_buckets) + if cached is not None: + return cached + result = set() + for combo in self._combinations: + row = GenericRow(list(combo), self._bucket_key_fields, + RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + # Skip the 4-byte length prefix — matches the writer's hash + # input exactly (see RowKeyExtractor._binary_row_hash_code). + h = _hash_bytes_by_words(serialized[4:]) + result.add(_bucket_from_hash(h, total_buckets)) + frozen = frozenset(result) + self._cache[total_buckets] = frozen + return frozen + + @property + def bucket_combinations(self) -> int: + """Number of (bucket-key) combinations used to compute the filter. + Exposed for tests / observability.""" + return len(self._combinations) + + +def create_bucket_selector( + predicate: Optional[Predicate], + bucket_key_fields: List[DataField]) -> Optional[Callable[[int, int], bool]]: + """Try to derive a bucket selector from ``predicate`` constrained to + ``bucket_key_fields``. + + Returns: + A callable ``(bucket, total_buckets) -> bool`` if the predicate + pins down all bucket keys to a finite Equal/In set; otherwise None + (caller must NOT prune by bucket). + """ + if predicate is None or not bucket_key_fields: + return None + + # See ``_UNSAFE_BUCKET_KEY_TYPES``: refuse pruning when the bucket-key + # column types are prone to writer/reader byte-level disagreement on + # equal logical values. Fail open rather than risk false-negatives. + if _has_unsafe_bucket_key_type(bucket_key_fields): + return None + + bk_name_to_slot: Dict[str, int] = { + f.name: i for i, f in enumerate(bucket_key_fields) + } + n_slots = len(bucket_key_fields) + slot_values: List[Optional[List[Any]]] = [None] * n_slots + + for and_child in _split_and(predicate): + extracted = _extract_or_clause(and_child, bk_name_to_slot) + if extracted is None: + # Not a bucket-key constraint — that's fine, just skip it. 
The + # remaining predicate still describes a SUPERSET of matching + # rows; bucket pruning stays sound as long as we don't *add* + # constraints that aren't actually true. + continue + slot, values = extracted + if slot_values[slot] is not None: + # Same bucket-key column constrained twice in top-level AND + # (e.g. ``id IN (1,2,3) AND id IN (2,3,4)``). Mirror Java's + # ``retainAll``: keep the intersection, bail only when it is + # empty (the predicate is unsatisfiable). + new_values_set = set(values) + intersection = [v for v in slot_values[slot] + if v in new_values_set] + if not intersection: + return None + slot_values[slot] = intersection + else: + slot_values[slot] = values + + # Every bucket-key column must be constrained. + for v in slot_values: + if v is None: + return None + + # Cartesian-product cap. Above the cap the bucket set is essentially + # all buckets anyway; punting saves the hash computation. + total = 1 + for v in slot_values: + # An empty slot (e.g. all literals were null) collapses the + # product to 0 — observable behaviour: empty bucket set, drop + # everything. Mirrors Java. 
+ total *= len(v) + if total > MAX_VALUES: + return None + + combinations = [list(combo) for combo in product(*slot_values)] + return _Selector(combinations, bucket_key_fields) diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py index 39c740174643..9ce4e4ebff99 100755 --- a/paimon-python/pypaimon/read/scanner/file_scanner.py +++ b/paimon-python/pypaimon/read/scanner/file_scanner.py @@ -35,6 +35,8 @@ trim_and_transform_predicate) from pypaimon.read.scanner.append_table_split_generator import \ AppendTableSplitGenerator +from pypaimon.read.scanner.bucket_select_converter import \ + create_bucket_selector from pypaimon.read.scanner.data_evolution_split_generator import \ DataEvolutionSplitGenerator from pypaimon.read.scanner.primary_key_table_split_generator import \ @@ -207,6 +209,12 @@ def __init__( self._scanned_snapshot = None self._scanned_snapshot_id = None + # Predicate-driven bucket pruning (HASH_FIXED only). Mirrors Java + # BucketSelectConverter. Set on demand and reused across all + # _filter_manifest_entry calls; the inner _Selector caches the + # bucket set per ``total_buckets`` value. + self._bucket_selector = self._init_bucket_selector() + def schema_fields_func(schema_id: int): return self.table.schema_manager.get_schema(schema_id).fields @@ -339,9 +347,36 @@ def read_manifest_entries(self, manifest_files: List[ManifestFileMeta]) -> List[ return self.manifest_file_manager.read_entries_parallel( manifest_files, self._filter_manifest_entry, - max_workers=max_workers + max_workers=max_workers, + early_entry_filter=self._build_early_bucket_filter(), ) + def _build_early_bucket_filter(self): + """Compose the (bucket, total_buckets) -> bool used by the manifest + reader to drop entries before deserialising ``_FILE`` / partition. + + Mirrors the BucketFilter applied at Java's InternalRow stage in + ``ManifestEntryCache``. 
The signature is intentionally minimal: + per-partition predicate pre-evaluation would also need + ``(partition, bucket, total_buckets)``, but the current selector + is partition-agnostic. + """ + only_real = self.only_read_real_buckets + selector = self._bucket_selector + if not only_real and selector is None: + return None + + def _filter(bucket: int, total_buckets: int) -> bool: + if only_real and bucket < 0: + return False + if (selector is not None + and bucket >= 0 + and not selector(bucket, total_buckets)): + return False + return True + + return _filter + def with_shard(self, idx_of_this_subtask: int, number_of_para_subtasks: int) -> 'FileScanner': if idx_of_this_subtask >= number_of_para_subtasks: raise ValueError("idx_of_this_subtask must be less than number_of_para_subtasks") @@ -386,9 +421,55 @@ def _filter_manifest_file(self, file: ManifestFileMeta) -> bool: file.partition_stats, file.num_added_files + file.num_deleted_files) + def _init_bucket_selector(self): + """Build the predicate-driven bucket selector if (and only if) the + table is in HASH_FIXED mode and the predicate pins all bucket-key + fields to Equal/In literals. Anything else returns None — the + caller treats None as "no bucket-level pruning". + + Bucket-key fields come from ``TableSchema.logical_bucket_key_fields`` + — the same source the writer's ``FixedBucketRowKeyExtractor`` reads + from, which is what makes the read/write hash agreement a property + of the schema rather than of any particular extractor instance. + + Sound across rescale: ``_Selector`` caches per ``total_buckets``, + which can vary between manifest entries after a bucket rescale. + """ + if self.predicate is None: + return None + # ``bucket_mode()`` returns HASH_FIXED only when ``options.bucket() + # > 0``; other modes (DYNAMIC / POSTPONE / UNAWARE / CROSS_PARTITION) + # have no fixed hash → bucket mapping at write time and must NOT + # be pruned here. 
+ try: + if self.table.bucket_mode() != BucketMode.HASH_FIXED: + return None + except Exception: + # Defensive: any catalog/proxy table that fails the mode check + # falls back to no pruning rather than crashing the scan. + return None + try: + bucket_key_fields = self.table.table_schema.logical_bucket_key_fields + except Exception: + # ``bucket_keys`` raises on misconfigured ``bucket-key`` (e.g. + # references an unknown column). The previous extractor-based + # path failed open here; preserve that — pruning is an + # optimisation, never a correctness requirement. + return None + if not bucket_key_fields: + return None + return create_bucket_selector(self.predicate, bucket_key_fields) + def _filter_manifest_entry(self, entry: ManifestEntry) -> bool: - if self.only_read_real_buckets and entry.bucket < 0: - return False + # NOTE: bucket-level filtering (``only_read_real_buckets`` + the + # predicate-driven selector) is enforced in the manifest reader's + # early filter (see ``_build_early_bucket_filter``) so rejected + # entries skip ``_FILE`` / partition decoding entirely. This + # method assumes that early filter has already run; a caller that + # bypasses ``read_entries_parallel`` and invokes this directly on + # raw entries MUST still apply ``_build_early_bucket_filter`` (or + # otherwise enforce ``bucket >= 0`` on POSTPONE tables) — this + # function alone is not sound on its own. 
if self.partition_key_predicate and not self.partition_key_predicate.test(entry.partition): return False # Get SimpleStatsEvolution for this schema diff --git a/paimon-python/pypaimon/schema/table_schema.py b/paimon-python/pypaimon/schema/table_schema.py index 53ddccfefcd2..789cf4e34cc5 100644 --- a/paimon-python/pypaimon/schema/table_schema.py +++ b/paimon-python/pypaimon/schema/table_schema.py @@ -64,6 +64,44 @@ def cross_partition_update(self) -> bool: # Return True if they don't contain all (cross-partition update) return not all(pk in self.primary_keys for pk in self.partition_keys) + @property + def bucket_keys(self) -> List[str]: + """Resolve the effective bucket-key column names. + + Resolution rule matches Java ``TableSchema.bucketKeys()``: prefer + the explicit ``bucket-key`` option; otherwise fall back to primary + keys with partition keys stripped (the same convention writers + use). + + Validation is intentionally narrower than Java's + ``originalBucketKeys()``: only ``unknown column name`` is checked + here. Java additionally enforces ``bucket-key`` ⊄ partition keys, + and (when primary keys are non-empty) ``bucket-key`` ⊆ primary + keys, but it does so once at schema construction. Doing the same + in a property would add per-read overhead and could surface + errors on tables already in the catalog. The narrow check here + is just enough to fail fast on the typo case. 
+ """ + configured = self.options.get(CoreOptions.BUCKET_KEY.key()) + if configured and configured.strip(): + keys = [k.strip() for k in configured.split(',') if k.strip()] + field_names = {f.name for f in self.fields} + missing = [k for k in keys if k not in field_names] + if missing: + raise ValueError( + "bucket-key references unknown columns: {}".format(missing)) + return keys + return [pk for pk in self.primary_keys if pk not in self.partition_keys] + + @property + def logical_bucket_key_fields(self) -> List[DataField]: + """The ``DataField``s for ``bucket_keys``, in the order they were + declared. Mirrors Java ``TableSchema.logicalBucketKeyType()``. + """ + field_map = {f.name: f for f in self.fields} + return [field_map[name] for name in self.bucket_keys + if name in field_map] + def to_schema(self) -> Schema: return Schema( fields=self.fields, diff --git a/paimon-python/pypaimon/tests/pushdown_bucket_test.py b/paimon-python/pypaimon/tests/pushdown_bucket_test.py new file mode 100644 index 000000000000..b83283200e8c --- /dev/null +++ b/paimon-python/pypaimon/tests/pushdown_bucket_test.py @@ -0,0 +1,739 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +""" +Three-layer correctness tests for predicate-driven bucket pruning. + +Mirrors Java's ``BucketSelectConverter`` contract: PK Equal/In queries on +HASH_FIXED tables must touch only the bucket(s) the writer would have +placed those keys in. Two correctness obligations: + + 1. Sound: the set of buckets retained by the selector is a superset + of the buckets holding matching rows. Buckets that DO contain + matching rows are NEVER dropped — false-negative-free. + 2. Hash-consistent with writers: ``RowKeyExtractor`` (writer) and + ``BucketSelectConverter`` (reader) must agree on every literal. + This is what makes ``pk = 'X'`` read the bucket holding 'X'. + +Layered: + * Unit — direct calls to ``create_bucket_selector`` with crafted + predicates, asserting selector behaviour. + * Integration — real PK tables with multiple buckets; queries; assert + (a) result correctness, (b) bucket pruning happened. + * Property — randomly-seeded PK tables, random Equal/In predicates, + result == oracle. No hypothesis dependency (keeps + Python 3.6 compat).
+""" + +import os +import random +import shutil +import tempfile +import unittest +from typing import Any, Dict, List + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.common.predicate_builder import PredicateBuilder +from pypaimon.read.scanner.bucket_select_converter import ( + MAX_VALUES, create_bucket_selector) +from pypaimon.schema.data_types import AtomicType, DataField +from pypaimon.write.row_key_extractor import (FixedBucketRowKeyExtractor, + _bucket_from_hash, + _hash_bytes_by_words) +from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer +from pypaimon.table.row.internal_row import RowKind + + +def _bigint_field(idx: int, name: str) -> DataField: + return DataField(idx, name, AtomicType('BIGINT', nullable=False)) + + +def _field(idx: int, name: str, type_name: str) -> DataField: + return DataField(idx, name, AtomicType(type_name, nullable=False)) + + +def _hash_bucket(values: List[Any], fields: List[DataField], total: int) -> int: + """Re-implement the writer's hash so unit tests can compute the + expected bucket without spinning up a real table.""" + row = GenericRow(values, fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + h = _hash_bytes_by_words(serialized[4:]) + return _bucket_from_hash(h, total) + + +# --------------------------------------------------------------------------- +# Layer 1 — Unit: drive ``create_bucket_selector`` with crafted predicates. 
+# --------------------------------------------------------------------------- +class BucketSelectConverterUnitTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.id_field = _bigint_field(0, 'id') + cls.val_field = _bigint_field(1, 'val') + cls.k1 = _bigint_field(0, 'k1') + cls.k2 = _bigint_field(1, 'k2') + cls.pb_id_val = PredicateBuilder([cls.id_field, cls.val_field]) + cls.pb_k1_k2 = PredicateBuilder([cls.k1, cls.k2]) + + # -- Equal / In on single bucket key --------------------------------- + def test_equal_on_single_bucket_key_yields_single_bucket(self): + sel = create_bucket_selector( + self.pb_id_val.equal('id', 42), [self.id_field]) + self.assertIsNotNone(sel, "PK Equal must produce a selector") + expected = _hash_bucket([42], [self.id_field], total=8) + for b in range(8): + self.assertEqual( + sel(b, 8), b == expected, + "only bucket {} must be kept (got {})".format(expected, b)) + + def test_in_on_single_bucket_key_unions_buckets(self): + sel = create_bucket_selector( + self.pb_id_val.is_in('id', [1, 2, 3, 100]), [self.id_field]) + expected = {_hash_bucket([v], [self.id_field], 8) + for v in (1, 2, 3, 100)} + for b in range(8): + self.assertEqual(sel(b, 8), b in expected) + + def test_or_of_equals_on_same_field_unions_buckets(self): + # ``id = 1 OR id = 2`` must equal ``id IN (1, 2)``. 
+ pred = PredicateBuilder.or_predicates([ + self.pb_id_val.equal('id', 1), + self.pb_id_val.equal('id', 2), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)} + for b in range(8): + self.assertEqual(sel(b, 8), b in expected) + + # -- Composite bucket keys ------------------------------------------ + def test_composite_bucket_key_intersects_via_cartesian(self): + pred = PredicateBuilder.and_predicates([ + self.pb_k1_k2.is_in('k1', [1, 2]), + self.pb_k1_k2.equal('k2', 99), + ]) + sel = create_bucket_selector(pred, [self.k1, self.k2]) + expected = { + _hash_bucket([k1, 99], [self.k1, self.k2], 4) + for k1 in (1, 2) + } + for b in range(4): + self.assertEqual(sel(b, 4), b in expected) + + def test_composite_bucket_key_missing_one_field_returns_none(self): + pred = self.pb_k1_k2.equal('k1', 1) # k2 unconstrained + sel = create_bucket_selector(pred, [self.k1, self.k2]) + self.assertIsNone(sel, + "all bucket keys must be constrained or fall back") + + # -- Predicates that can't be reduced ------------------------------- + def test_non_bucket_key_predicate_returns_none(self): + sel = create_bucket_selector( + self.pb_id_val.equal('val', 5), [self.id_field]) + self.assertIsNone(sel, "predicate not on bucket key -> no selector") + + def test_range_predicate_on_bucket_key_returns_none(self): + sel = create_bucket_selector( + self.pb_id_val.greater_than('id', 100), [self.id_field]) + self.assertIsNone(sel, "ranges can't be turned into a finite bucket set") + + def test_or_with_non_bucket_key_returns_none(self): + # ``id = 1 OR val = 5`` — ``val`` isn't a bucket key, so the OR + # is not a pure bucket-key constraint. 
+ pred = PredicateBuilder.or_predicates([ + self.pb_id_val.equal('id', 1), + self.pb_id_val.equal('val', 5), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNone(sel) + + def test_repeated_equal_on_same_key_with_empty_intersection_returns_none(self): + # ``id = 1 AND id = 2``: literal sets {1} and {2} intersect to + # empty; Java's ``retainAll`` would also bail here, since the + # predicate is unsatisfiable. + pred = PredicateBuilder.and_predicates([ + self.pb_id_val.equal('id', 1), + self.pb_id_val.equal('id', 2), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNone(sel) + + def test_repeated_in_on_same_key_intersects_literals(self): + # ``id IN (1,2,3) AND id IN (2,3,4)`` should now keep the + # intersection {2, 3} and prune to those buckets only. Used to + # bail with no selector before the Java parity fix. + pred = PredicateBuilder.and_predicates([ + self.pb_id_val.is_in('id', [1, 2, 3]), + self.pb_id_val.is_in('id', [2, 3, 4]), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNotNone(sel) + expected = {_hash_bucket([v], [self.id_field], 8) for v in (2, 3)} + for b in range(8): + self.assertEqual(sel(b, 8), b in expected) + + def test_and_with_unrelated_clause_is_unaffected(self): + # ``id = 7 AND val > 100`` — the ``val > 100`` part doesn't + # constrain buckets, but mustn't disqualify the ``id = 7`` part. + pred = PredicateBuilder.and_predicates([ + self.pb_id_val.equal('id', 7), + self.pb_id_val.greater_than('val', 100), + ]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNotNone(sel) + expected = _hash_bucket([7], [self.id_field], 4) + for b in range(4): + self.assertEqual(sel(b, 4), b == expected) + + # -- Cap & degenerate edge cases ------------------------------------ + def test_cartesian_above_max_values_returns_none(self): + # Two columns of size > sqrt(MAX_VALUES) → product > MAX_VALUES. 
+ size = 33 # 33 * 33 = 1089 > 1000 + pred = PredicateBuilder.and_predicates([ + self.pb_k1_k2.is_in('k1', list(range(size))), + self.pb_k1_k2.is_in('k2', list(range(size))), + ]) + self.assertGreater(size * size, MAX_VALUES) + sel = create_bucket_selector(pred, [self.k1, self.k2]) + self.assertIsNone(sel) + + def test_null_only_literal_drops_everything(self): + # ``id IN (NULL)`` after null-stripping has zero literals; the + # cartesian product is empty → selector matches no buckets. Same + # behaviour as Java. + pred = self.pb_id_val.is_in('id', [None]) + sel = create_bucket_selector(pred, [self.id_field]) + self.assertIsNotNone(sel) + for b in range(4): + self.assertFalse(sel(b, 4), + "all-null literal collapses bucket set to empty") + + def test_no_predicate_returns_none(self): + self.assertIsNone(create_bucket_selector(None, [self.id_field])) + + def test_no_bucket_keys_returns_none(self): + self.assertIsNone( + create_bucket_selector(self.pb_id_val.equal('id', 1), [])) + + # -- Selector cache + rescale ------------------------------------- + def test_selector_caches_per_total_buckets(self): + """Selector must answer correctly when the same query applies to + different ``total_buckets`` values (the rescale scenario).""" + sel = create_bucket_selector( + self.pb_id_val.equal('id', 42), [self.id_field]) + for total in (4, 8, 16, 32): + expected = _hash_bucket([42], [self.id_field], total) + self.assertTrue(sel(expected, total)) + other = (expected + 1) % total + self.assertFalse(sel(other, total)) + + def test_non_positive_total_buckets_fails_open(self): + """Manifest entries can carry ``total_buckets <= 0`` for legacy / + special bucket modes. Pruning MUST fail open — returning False + would silently drop rows the writer placed in those entries. 
+ This is correctness, not performance: the soundness contract + forbids false-negatives.""" + sel = create_bucket_selector( + self.pb_id_val.equal('id', 1), [self.id_field]) + for total in (0, -1, -2): + self.assertTrue(sel(0, total), + "total_buckets={} must be kept (fail open)".format(total)) + self.assertTrue(sel(-1, total)) + self.assertTrue(sel(99, total)) + + # -- Bucket-key column types beyond BIGINT -------------------------- + def test_string_bucket_key_yields_correct_bucket(self): + """STRING uses a different ``GenericRowSerializer`` path (utf-8 + encode + variable-part offset) — verify writer/reader agree on + its byte layout independent of the BIGINT happy path.""" + sf = _field(0, 'sk', 'STRING') + vf = _bigint_field(1, 'val') + pb = PredicateBuilder([sf, vf]) + sel = create_bucket_selector(pb.equal('sk', 'hello'), [sf]) + self.assertIsNotNone(sel) + expected = _hash_bucket(['hello'], [sf], total=8) + for b in range(8): + self.assertEqual(sel(b, 8), b == expected) + + def test_int_bucket_key_yields_correct_bucket(self): + """INT (32-bit) and BIGINT (64-bit) hit different struct.pack + paths in the serializer — guard the smaller width.""" + intf = _field(0, 'i', 'INT') + vf = _bigint_field(1, 'val') + pb = PredicateBuilder([intf, vf]) + sel = create_bucket_selector(pb.equal('i', 7), [intf]) + self.assertIsNotNone(sel) + expected = _hash_bucket([7], [intf], total=4) + for b in range(4): + self.assertEqual(sel(b, 4), b == expected) + + # -- Hash-divergence-prone types refuse to build a selector -------- + def test_decimal_bucket_key_disables_pruning(self): + """DECIMAL columns risk silent hash divergence between writer + (Decimal) and reader-supplied ``float`` literals. 
Soundness + contract demands fail-open: refuse to build a selector at all.""" + df = _field(0, 'd', 'DECIMAL(10, 2)') + vf = _bigint_field(1, 'val') + pb = PredicateBuilder([df, vf]) + from decimal import Decimal + sel = create_bucket_selector(pb.equal('d', Decimal('1.50')), [df]) + self.assertIsNone( + sel, "DECIMAL bucket-key column must disable pruning") + + def test_array_bucket_key_disables_pruning(self): + """Composite types (ARRAY/MAP/ROW/MULTISET/VARIANT/BLOB) have no + cross-validated byte alignment with Java's ``BinaryRow`` — until + that exists, refuse to prune on them.""" + # Hand-roll a DataField whose AtomicType reports an ARRAY type + # name; the converter inspects ``field.type.type`` only. + af = DataField(0, 'arr', AtomicType('ARRAY')) + vf = _bigint_field(1, 'val') + pb = PredicateBuilder([af, vf]) + sel = create_bucket_selector(pb.equal('arr', [1]), [af]) + self.assertIsNone( + sel, "ARRAY bucket-key column must disable pruning") + + def test_timestamp_bucket_key_disables_pruning(self): + """TIMESTAMP columns serialise via ``value.timestamp()`` whose + result depends on the process timezone for naive datetimes — + writer and reader running in different TZs would disagree.""" + tf = _field(0, 't', 'TIMESTAMP(3)') + vf = _bigint_field(1, 'val') + pb = PredicateBuilder([tf, vf]) + from datetime import datetime + sel = create_bucket_selector( + pb.equal('t', datetime(2026, 1, 1)), [tf]) + self.assertIsNone( + sel, "TIMESTAMP bucket-key column must disable pruning") + + def test_type_mismatched_literal_fails_open_not_crash(self): + """If the user constructs a predicate whose literal type doesn't + match the bucket-key column's atomic type — e.g. a STRING literal + on a BIGINT column — ``GenericRowSerializer`` raises during the + deferred hash inside ``_Selector``. The selector MUST swallow the + exception and fail open (return True for every bucket) rather + than propagate it. 
Crashing the entire scan with an opaque + ``struct.error`` is a worse user experience than silently + skipping bucket pruning, and the soundness contract still + forbids false-negatives.""" + sel = create_bucket_selector( + self.pb_id_val.equal('id', 'not-an-int'), [self.id_field]) + # Construction itself succeeds (no eager hashing). + self.assertIsNotNone(sel) + # Calling the selector must NOT raise; instead it returns True + # for every (bucket, total_buckets), preserving soundness. + for total in (4, 8): + for b in range(total): + self.assertTrue(sel(b, total), + "type-mismatched literal must fail open, " + "not crash (bucket={}, total={})".format(b, total)) + + +# --------------------------------------------------------------------------- +# Layer 2 — Integration: real tables, public API, assert correctness AND +# that pruning actually fired (otherwise we're not testing the optimisation, +# only that we didn't break full-scan). +# --------------------------------------------------------------------------- +class BucketPruningIntegrationTest(unittest.TestCase): + + NUM_BUCKETS = 8 + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', False) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_pk_table(self, name: str, num_buckets: int = NUM_BUCKETS, + bucket_key: str = None) -> Any: + opts = {'bucket': str(num_buckets), 'file.format': 'parquet'} + if bucket_key is not None: + opts['bucket-key'] = bucket_key + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('val', pa.int64()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, primary_keys=['id'], options=opts) + full = 'default.{}'.format(name) + self.catalog.create_table(full, schema, False) + return self.catalog.get_table(full) + + def 
_write(self, table, rows: List[Dict]): + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('val', pa.int64()), + ]) + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + def _read_with(self, table, predicate=None): + rb = table.new_read_builder() + if predicate is not None: + rb = rb.with_filter(predicate) + splits = rb.new_scan().plan().splits() + if not splits: + return [], splits + return rb.new_read().to_arrow(splits).to_pylist(), splits + + @staticmethod + def _split_buckets(splits) -> set: + """Collect the distinct bucket numbers actually returned in a plan.""" + return {s.bucket for s in splits} + + @staticmethod + def _expected_buckets(table, ids, value_field='val') -> set: + """Use the writer's RowKeyExtractor to compute the bucket(s) the + rows for ``ids`` were written into. Cross-check against the + reader's selector — divergence indicates read/write hash drift.""" + ext = FixedBucketRowKeyExtractor(table.table_schema) + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + (value_field, pa.int64()), + ]) + out = set() + for i in ids: + arr = pa.RecordBatch.from_pylist( + [{'id': i, value_field: 0}], schema=pa_schema) + out.update(ext._extract_buckets_batch(arr)) + return out + + # -- Equal on PK ----------------------------------------------------- + def test_pk_equal_only_reads_target_bucket(self): + table = self._create_pk_table('int_eq') + rows = [{'id': i, 'val': i * 11} for i in range(100)] + self._write(table, rows) + + target_id = 42 + pred = table.new_read_builder().new_predicate_builder().equal( + 'id', target_id) + got, splits = self._read_with(table, pred) + + # Correctness: row for id=42 returned (and only that). 
+ self.assertEqual(got, [{'id': 42, 'val': 42 * 11}]) + + # Pruning effectiveness AND hash correctness: the touched bucket + # must equal the bucket the writer placed id=42 into. Asserting + # only ``len == 1`` would mask a hash drift that picks the wrong + # single bucket. + self.assertEqual(self._split_buckets(splits), + self._expected_buckets(table, [target_id]), + "PK equal must touch exactly the writer's bucket") + + def test_pk_in_reads_only_target_buckets(self): + table = self._create_pk_table('int_in') + rows = [{'id': i, 'val': i * 7} for i in range(200)] + self._write(table, rows) + + ids = [3, 17, 99, 150] + pred = table.new_read_builder().new_predicate_builder().is_in( + 'id', ids) + got, splits = self._read_with(table, pred) + got_sorted = sorted(got, key=lambda r: r['id']) + self.assertEqual(got_sorted, + [{'id': i, 'val': i * 7} for i in sorted(ids)]) + + actual = self._split_buckets(splits) + expected_buckets = self._expected_buckets(table, ids) + # Equality (not subset): under the single-commit setup every + # target bucket actually has a file, so the planner must produce + # exactly the writer's bucket set. ``issubset`` would mask a + # selector that's overly aggressive on a subset of the IN list. + self.assertEqual(actual, expected_buckets, + "got {}, expected {}".format(actual, expected_buckets)) + + # -- Predicates that should NOT prune ------------------------------- + def test_value_only_predicate_falls_back_to_full_scan(self): + """``val < X`` doesn't constrain the PK → selector must be None + and no bucket pruning may fire. 
Both checked: result correctness + AND the explicit "selector is None" property.""" + table = self._create_pk_table('val_only') + rows = [{'id': i, 'val': i} for i in range(100)] + self._write(table, rows) + + pred = table.new_read_builder().new_predicate_builder().less_than( + 'val', 30) + got, splits = self._read_with(table, pred) + self.assertEqual(sorted([r['id'] for r in got]), list(range(30))) + + # Inspect the scanner's bucket selector to prove pruning DIDN'T + # fire — without this assertion the test would also pass under a + # buggy selector that prunes wrongly but happens to keep the + # rows we picked. + rb = table.new_read_builder().with_filter(pred) + scan = rb.new_scan() + self.assertIsNone(scan.file_scanner._bucket_selector, + "value-only predicate must NOT produce a selector") + + def test_range_on_pk_falls_back_to_full_scan(self): + """``id > X`` is a range, not Equal/In, so cannot derive a bucket + set. Selector returns None — result must still be exact.""" + table = self._create_pk_table('pk_range') + rows = [{'id': i, 'val': i} for i in range(50)] + self._write(table, rows) + + pred = table.new_read_builder().new_predicate_builder().greater_or_equal( + 'id', 40) + got, _ = self._read_with(table, pred) + self.assertEqual(sorted([r['id'] for r in got]), list(range(40, 50))) + + # -- Mixed predicate: Equal on PK AND range on val ------------------ + def test_pk_equal_with_unrelated_value_predicate_still_prunes(self): + table = self._create_pk_table('int_eq_with_val') + rows = [{'id': i, 'val': i} for i in range(50)] + self._write(table, rows) + + pb = table.new_read_builder().new_predicate_builder() + pred = pb.and_predicates([ + pb.equal('id', 25), + pb.greater_than('val', 20), + ]) + got, splits = self._read_with(table, pred) + self.assertEqual(got, [{'id': 25, 'val': 25}]) + self.assertEqual(self._split_buckets(splits), + self._expected_buckets(table, [25]), + "Equal on PK still narrows to the writer's bucket " + "even when AND'd with a 
non-bucket-key predicate") + + def test_early_filter_skips_full_entry_decode_for_pruned_buckets(self): + """Entries the bucket selector rejects must never reach + ``GenericRowDeserializer.from_bytes`` for their partition / key + stats. Without the early filter the count would scale with the + manifest entry count; with it, only the surviving entries pay + the deserialisation cost.""" + from unittest import mock + + from pypaimon.table.row import generic_row + + table = self._create_pk_table('early_filter') + # 8 separate single-row commits → 8 manifest entries each touching + # a different bucket. ``pk = X`` should reach exactly one of them. + for i in range(self.NUM_BUCKETS): + self._write(table, [{'id': i, 'val': i * 11}]) + + pred = table.new_read_builder().new_predicate_builder().equal('id', 0) + rb = table.new_read_builder().with_filter(pred) + + real_from_bytes = generic_row.GenericRowDeserializer.from_bytes + calls = {'n': 0} + + def counting(*args, **kwargs): + calls['n'] += 1 + return real_from_bytes(*args, **kwargs) + + with mock.patch.object(generic_row.GenericRowDeserializer, + 'from_bytes', + side_effect=counting): + splits = rb.new_scan().plan().splits() + got = rb.new_read().to_arrow(splits).to_pylist() if splits else [] + + self.assertEqual(got, [{'id': 0, 'val': 0}]) + # Each surviving entry decodes partition + min_key + max_key + # (3 ``from_bytes`` calls). Allow a small slack in case the planner + # touches extras, but assert it is well below 8 entries × 3 = 24. + self.assertLess( + calls['n'], 3 * self.NUM_BUCKETS, + "early filter should skip from_bytes for pruned entries; " + "got {} calls (would be {}+ without the filter)".format( + calls['n'], 3 * self.NUM_BUCKETS)) + + def test_init_bucket_selector_fails_open_when_bucket_keys_raises(self): + """``TableSchema.bucket_keys`` raises if ``bucket-key`` references + an unknown column. 
The pre-Java-alignment selector path used to + catch ``Exception`` from instantiating ``FixedBucketRowKeyExtractor`` + and silently skip pruning; that property must survive the move + of bucket-key resolution onto ``TableSchema``. Crashing the scan + on a misconfiguration would be worse than skipping the + optimisation.""" + table = self._create_pk_table('init_fails_open') + self._write(table, [{'id': 1, 'val': 1}]) + # Mutate the in-memory schema options to a broken value to + # simulate a corrupted/migrated catalog without rewriting it. + table.table_schema.options['bucket-key'] = 'nope_no_such_column' + + rb = table.new_read_builder().with_filter( + table.new_read_builder().new_predicate_builder().equal('id', 1)) + scanner = rb.new_scan().file_scanner + # Must NOT raise: the broken option falls back to "no pruning", + # and the scan still finds the row. + self.assertIsNone(scanner._init_bucket_selector()) + got, _ = self._read_with(table, scanner.predicate) + self.assertEqual(got, [{'id': 1, 'val': 1}]) + + # -- Explicit bucket-key option ------------------------------------ + def test_bucket_key_option_overrides_pk_for_pruning(self): + """When the ``bucket-key`` option is set explicitly, the bucket + derivation must use it — not the trimmed primary keys. This is + the path that catches read/write hash divergence if a refactor + forgets the option.""" + # PK = id, bucket-key = id explicitly (single key but exercises + # the explicit-config branch in ``_init_bucket_selector``). 
+ table = self._create_pk_table('explicit_bk', bucket_key='id') + rows = [{'id': i, 'val': i * 3} for i in range(40)] + self._write(table, rows) + + pred = table.new_read_builder().new_predicate_builder().equal('id', 17) + got, splits = self._read_with(table, pred) + self.assertEqual(got, [{'id': 17, 'val': 51}]) + self.assertEqual(self._split_buckets(splits), + self._expected_buckets(table, [17])) + + +# --------------------------------------------------------------------------- +# Layer 3 — Property: random PK tables, random Equal/In predicates, +# correctness vs oracle. +# --------------------------------------------------------------------------- +class BucketPruningPropertyTest(unittest.TestCase): + + SEED = 0xB0CC + TRIALS = 30 + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', False) + cls.rnd = random.Random(cls.SEED) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _make_table(self, idx: int, num_buckets: int): + pa_schema = pa.schema([ + pa.field('k', pa.int64(), nullable=False), + ('v', pa.int64()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + primary_keys=['k'], + options={'bucket': str(num_buckets), 'file.format': 'parquet'}, + ) + name = 'default.bp_{}'.format(idx) + self.catalog.create_table(name, schema, False) + return self.catalog.get_table(name) + + def _write(self, table, rows): + pa_schema = pa.schema([ + pa.field('k', pa.int64(), nullable=False), + ('v', pa.int64()), + ]) + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + @staticmethod + def _expected_buckets(table, keys) -> set: + """Independent oracle: writer's bucket 
placement for the given keys.""" + ext = FixedBucketRowKeyExtractor(table.table_schema) + pa_schema = pa.schema([ + pa.field('k', pa.int64(), nullable=False), + ('v', pa.int64()), + ]) + out = set() + for k in keys: + arr = pa.RecordBatch.from_pylist( + [{'k': k, 'v': 0}], schema=pa_schema) + out.update(ext._extract_buckets_batch(arr)) + return out + + def test_property_pk_equal_correctness(self): + for trial in range(self.TRIALS): + num_buckets = self.rnd.choice([2, 4, 8, 16]) + table = self._make_table(trial, num_buckets) + keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100)) + rows = [{'k': k, 'v': k * 13} for k in keys] + self._write(table, rows) + + target = self.rnd.choice(keys) + pb = table.new_read_builder().new_predicate_builder() + pred = pb.equal('k', target) + rb = table.new_read_builder().with_filter(pred) + splits = rb.new_scan().plan().splits() + if splits: + got = rb.new_read().to_arrow(splits).to_pylist() + else: + got = [] + self.assertEqual(got, [{'k': target, 'v': target * 13}], + "trial {} buckets={} target={}: result mismatch" + .format(trial, num_buckets, target)) + # Pruning fired AND picked the writer's bucket. Without this + # cross-check a fail-open selector (i.e. no pruning) would + # still pass the result-equality assertion above. 
+ self.assertEqual(self._split_buckets(splits), + self._expected_buckets(table, [target]), + "trial {}: bucket set != writer's placement" + .format(trial)) + + def test_property_pk_in_correctness(self): + for trial in range(self.TRIALS): + num_buckets = self.rnd.choice([2, 4, 8, 16]) + offset = self.TRIALS + trial # avoid name clash with prev test + table = self._make_table(offset, num_buckets) + keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100)) + rows = [{'k': k, 'v': k * 13} for k in keys] + self._write(table, rows) + + target_n = self.rnd.randint(1, min(10, len(keys))) + targets = self.rnd.sample(keys, target_n) + pb = table.new_read_builder().new_predicate_builder() + pred = pb.is_in('k', targets) + rb = table.new_read_builder().with_filter(pred) + splits = rb.new_scan().plan().splits() + if splits: + got = rb.new_read().to_arrow(splits).to_pylist() + else: + got = [] + got_sorted = sorted(got, key=lambda r: r['k']) + want = sorted( + [{'k': k, 'v': k * 13} for k in targets], + key=lambda r: r['k']) + self.assertEqual(got_sorted, want, + "trial {}: IN result mismatch".format(trial)) + self.assertEqual(self._split_buckets(splits), + self._expected_buckets(table, targets), + "trial {}: IN bucket set != writer's placement" + .format(trial)) + + @staticmethod + def _split_buckets(splits) -> set: + return {s.bucket for s in splits} + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/tests/table_schema_test.py b/paimon-python/pypaimon/tests/table_schema_test.py new file mode 100644 index 000000000000..d42eef2ab8ca --- /dev/null +++ b/paimon-python/pypaimon/tests/table_schema_test.py @@ -0,0 +1,93 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import unittest + +from pypaimon.schema.data_types import AtomicType, DataField +from pypaimon.schema.table_schema import TableSchema + + +def _bigint_field(idx: int, name: str) -> DataField: + return DataField(idx, name, AtomicType('BIGINT', nullable=False)) + + +def _string_field(idx: int, name: str) -> DataField: + return DataField(idx, name, AtomicType('STRING')) + + +class TableSchemaBucketKeysTest(unittest.TestCase): + """Cover the ``bucket-key`` resolution lifted onto TableSchema. + + Mirrors Java ``TableSchema.bucketKeys()`` / ``logicalBucketKeyType()``. 
+ """ + + def _schema(self, primary_keys=None, partition_keys=None, options=None): + fields = [ + _bigint_field(0, 'id'), + _string_field(1, 'region'), + _bigint_field(2, 'val'), + ] + return TableSchema( + id=0, + fields=fields, + partition_keys=partition_keys or [], + primary_keys=primary_keys or [], + options=options or {}, + ) + + def test_explicit_bucket_key_option_returns_those_columns(self): + schema = self._schema( + primary_keys=['id'], + options={'bucket-key': 'region,val'}, + ) + self.assertEqual(schema.bucket_keys, ['region', 'val']) + + fields = schema.logical_bucket_key_fields + self.assertEqual([f.name for f in fields], ['region', 'val']) + + def test_no_bucket_key_falls_back_to_trimmed_primary_keys(self): + # PK includes a partition column; trimmed bucket keys drop it. + schema = self._schema( + primary_keys=['region', 'id'], + partition_keys=['region'], + ) + self.assertEqual(schema.bucket_keys, ['id']) + self.assertEqual( + [f.name for f in schema.logical_bucket_key_fields], ['id']) + + def test_no_bucket_key_no_primary_keys_returns_empty(self): + schema = self._schema() + self.assertEqual(schema.bucket_keys, []) + self.assertEqual(schema.logical_bucket_key_fields, []) + + def test_unknown_bucket_key_column_raises(self): + schema = self._schema(options={'bucket-key': 'nope'}) + with self.assertRaises(ValueError): + _ = schema.bucket_keys + + def test_whitespace_only_option_falls_back(self): + # Whitespace-only ``bucket-key`` mirrors an unset option. 
+ schema = self._schema( + primary_keys=['id'], + options={'bucket-key': ' '}, + ) + self.assertEqual(schema.bucket_keys, ['id']) + + +if __name__ == '__main__': + unittest.main() diff --git a/paimon-python/pypaimon/write/row_key_extractor.py b/paimon-python/pypaimon/write/row_key_extractor.py index 2f09e6577b20..c30b63e21b3c 100644 --- a/paimon-python/pypaimon/write/row_key_extractor.py +++ b/paimon-python/pypaimon/write/row_key_extractor.py @@ -125,18 +125,12 @@ def __init__(self, table_schema: TableSchema): if self.num_buckets <= 0: raise ValueError(f"Fixed bucket mode requires bucket > 0, got {self.num_buckets}") - bucket_key_option = options.bucket_key() - if bucket_key_option and bucket_key_option.strip(): - self.bucket_keys = [k.strip() for k in bucket_key_option.split(',')] - else: - self.bucket_keys = [pk for pk in table_schema.primary_keys - if pk not in table_schema.partition_keys] - + # Bucket-key resolution lives on TableSchema (mirrors Java + # ``TableSchema.bucketKeys()`` / ``logicalBucketKeyType()``); reuse + # it so any reader path that walks the same logic stays in sync. + self.bucket_keys = table_schema.bucket_keys self.bucket_key_indices = self._get_field_indices(self.bucket_keys) - field_map = {f.name: f for f in table_schema.fields} - self._bucket_key_fields = [ - field_map[name] for name in self.bucket_keys if name in field_map - ] + self._bucket_key_fields = table_schema.logical_bucket_key_fields def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]: columns = [data.column(i) for i in self.bucket_key_indices]