[Experimental] Add FilterComponent for row-level example filtering #7868

Open
vkarampudi wants to merge 1 commit into tensorflow:master from vkarampudi:feature/filter-component

@vkarampudi (Collaborator)

Description

This PR introduces an experimental FilterComponent to TFX, addressing a long-standing feature request (Issue #6909) for row-level example filtering within pipelines.

Currently, TFX has no built-in component to filter anomalous or unwanted examples after ingestion. The Transform component (TFT) is strictly map-only and does not support dropping rows. As a result, users have had to write custom Apache Beam code or perform out-of-band pre-processing. The new FilterComponent provides a clean, native, and reusable way to filter datasets split-by-split using a user-defined Python function.

Proposed Changes

  1. Component Definition (component.py):
    • Created FilterSpec and FilterComponent.
    • Accepts an input Examples channel and a string import path to a Python filter function (filter_fn_path).
    • Outputs a filtered_examples Examples channel.
  2. Beam Executor (executor.py):
    • Implements the filtering as an Apache Beam pipeline.
    • Automatically decodes all splits present in the input Examples artifact (e.g., train, eval).
    • Dynamically imports the user-defined filter function and applies beam.Filter.
    • Writes each output split as gzip-compressed TFRecord files while preserving artifact metadata (spans, versions, split names) for downstream components.
  3. Tests:
    • component_test.py: Verifies component instantiation and spec parsing.
    • executor_test.py: Runs a full integration test with simulated train and eval splits, executing the Beam filter pipeline and verifying correct output records.
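
To illustrate the executor logic described above, here is a minimal standard-library sketch of the two core steps: resolving a dotted `filter_fn_path` to a callable, and applying that predicate to each split independently. It is not the code in this PR. The helper names `resolve_filter_fn` and `filter_splits` are hypothetical, the splits are modeled as an in-memory dict of serialized records, and a list comprehension stands in for `beam.Filter`.

```python
import importlib
from typing import Callable, Dict, Iterable, List

# A filter function takes one serialized record and returns True to keep it.
FilterFn = Callable[[bytes], bool]


def resolve_filter_fn(filter_fn_path: str) -> FilterFn:
    """Resolve a dotted path like 'my_filters.age_filter_fn' (hypothetical helper)."""
    module_path, sep, fn_name = filter_fn_path.rpartition(".")
    if not sep or not module_path or not fn_name:
        raise ValueError(f"Expected 'module.function', got {filter_fn_path!r}")
    module = importlib.import_module(module_path)
    fn = getattr(module, fn_name)
    if not callable(fn):
        raise TypeError(f"{filter_fn_path!r} is not callable")
    return fn


def filter_splits(
    splits: Dict[str, Iterable[bytes]], filter_fn: FilterFn
) -> Dict[str, List[bytes]]:
    """Apply the predicate to every split, keeping the split names unchanged."""
    return {
        name: [record for record in records if filter_fn(record)]
        for name, records in splits.items()
    }
```

Resolving the function by import path means the module holding the filter function has to be importable wherever the predicate runs, which is presumably why the usage example places it in a shared module.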

Usage Example

# 1. Define a filter function in a shared module (e.g., my_filters.py)
import tensorflow as tf

def age_filter_fn(serialized_example: bytes) -> bool:
  """Keeps examples whose first 'age' value is greater than 18."""
  example = tf.train.Example()
  example.ParseFromString(serialized_example)
  age_feature = example.features.feature.get('age')
  # Drop examples where 'age' is missing or holds no int64 values.
  if age_feature is None or not age_feature.int64_list.value:
    return False
  return age_feature.int64_list.value[0] > 18

# 2. Add the FilterComponent to your pipeline
# (example_gen and schema_gen are components defined earlier in the pipeline)
from tfx.components import Transform
from tfx.components.experimental.filter.component import FilterComponent

filter_data = FilterComponent(
    examples=example_gen.outputs['examples'],
    filter_fn_path='my_filters.age_filter_fn'
)

# 3. Pass the filtered output to Transform/Trainer
transform = Transform(
    examples=filter_data.outputs['filtered_examples'],
    schema=schema_gen.outputs['schema'],
    preprocessing_fn=...
)
