From 2e4bd3bb52d5dc583c8ba39c48a0de07e3bef7b6 Mon Sep 17 00:00:00 2001 From: Pooya Moradi Date: Tue, 26 May 2026 07:40:58 +0000 Subject: [PATCH] Add reward_functions_path + reward_functions CLI knobs for custom rewards Currently the reward stack is hardcoded to a 3-fn list: [match_format_exactly, match_format_approximately, check_numbers]. Replacing it requires editing train_rl.py. Add two new config fields: - reward_functions_path: path to a user Python file - reward_functions: comma-separated list of function names to import When both are set, the built-in reward stack is REPLACED entirely by the user-provided functions (so users can pin a single VTC-style partial-credit reward, swap in a math_verify-based scorer, etc., without editing maxtext). Each user function must accept (prompts, completions, tmvp_config, **kwargs) and return a list of floats. Default (both empty) keeps existing behavior unchanged. Reuses `_load_custom_callable` helper added in the previous commit. --- src/maxtext/configs/post_train/rl.yml | 7 ++++ src/maxtext/configs/types.py | 14 +++++++ .../trainers/post_train/rl/train_rl.py | 37 ++++++++++++++--- tests/post_training/unit/train_rl_test.py | 41 +++++++++++++++++++ 4 files changed, 93 insertions(+), 6 deletions(-) diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 9ba55c296f..5775368b1f 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -238,6 +238,13 @@ eval_dataset_name: 'openai/gsm8k' # tmvp_config, x) -> dict with keys {prompts, question, answer}. When empty # (default), the built-in utils_rl.process_data is used. dataset_processor_path: '' +# Optional: path to a user-provided Python file with custom reward functions, +# AND a comma-separated list of function names to import from that file. +# Each function signature: fn(prompts, completions, tmvp_config, **kwargs) -> list[float]. +# When both are set, the built-in [match_format_exactly, match_format_approximately, +# check_numbers] reward list is REPLACED entirely. +reward_functions_path: '' +reward_functions: '' train_split: 'train' eval_split: 'test' hf_name: 'main' # subset of Hugging Face dataset diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b70b7238d3..f16aa6dc30 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2056,6 +2056,20 @@ class RLDataset(BaseModel): "When set, replaces the built-in dataset processor for custom datasets." ), ) + reward_functions_path: str = Field( + "", + description=( + "Optional path to a user Python file containing custom reward functions. " + "Used with `reward_functions` to fully replace the built-in reward stack." + ), + ) + reward_functions: str = Field( + "", + description=( + "Comma-separated names of reward functions to import from `reward_functions_path`. " + "Each function signature: (prompts, completions, tmvp_config, **kwargs) -> list[float]." + ), + ) class RLEvaluation(BaseModel): diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index fed793be21..ed14153806 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -46,7 +46,7 @@ from __future__ import annotations import contextlib from functools import wraps -from typing import Any, Optional, Sequence +from typing import Any, Callable, Optional, Sequence import datasets import grain @@ -363,6 +363,31 @@ def _use_raw_prompt(x): return train_dataset, test_dataset +def build_reward_fns(trainer_config: Any, make_reward_fn: Callable) -> list: + """Build the reward-function stack for the RL trainer. + + `reward_functions_path` is a filesystem path to a Python file and + `reward_functions` is a comma-separated list of function names to import from + it. When both are set, the built-in stack is REPLACED entirely by the + user-provided callables (so users have full control over their reward stack). + Otherwise the default + (`match_format_exactly`, `match_format_approximately`, `check_numbers`) stack + is used. Every reward function is wrapped via `make_reward_fn`. + """ + custom_rewards_path = getattr(trainer_config, "reward_functions_path", "") or "" + custom_rewards_names = getattr(trainer_config, "reward_functions", "") or "" + if custom_rewards_path and custom_rewards_names: + names = [n.strip() for n in custom_rewards_names.split(",") if n.strip()] + reward_fns = [make_reward_fn(utils_rl.load_custom_callable(custom_rewards_path, n)) for n in names] + max_logging.log(f"reward_fns: using {len(reward_fns)} custom reward function(s) {names} from {custom_rewards_path}") + return reward_fns + return [ + make_reward_fn(utils_rl.match_format_exactly), + make_reward_fn(utils_rl.match_format_approximately), + make_reward_fn(utils_rl.check_numbers), + ] + + def create_rl_components( trainer_config, sampler_config, @@ -526,11 +551,11 @@ def _reward_fn(**kwargs): return _reward_fn - reward_fns = [ # type: ignore - make_reward_fn(utils_rl.match_format_exactly), - make_reward_fn(utils_rl.match_format_approximately), - make_reward_fn(utils_rl.check_numbers), - ] + # Optional user-provided reward functions: when `reward_functions_path` and + # `reward_functions` are both set the built-in stack is replaced entirely by + # the user-provided callables. Each function must accept `prompts`, + # `completions`, `tmvp_config`, and `**kwargs` and return a list of floats. + reward_fns = build_reward_fns(trainer_config, make_reward_fn) # Create RL trainer max_logging.log("Setting up RL trainer...") diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index be66a49cc0..bb89424aa6 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -495,6 +495,47 @@ def test_rl_train_invalid_optimizer_memory_host_offload(self, mock_setup): with self.assertRaisesRegex(ValueError, "optimizer_memory_host_offload=True is not supported"): train_rl._rl_train_impl([], {}) # pylint: disable=protected-access + @pytest.mark.cpu_only + def test_build_reward_fns_defaults_when_no_custom(self): + """With neither knob set, the built-in 3-fn stack is returned.""" + trainer_config = SimpleNamespace(reward_functions_path="", reward_functions="") + reward_fns = train_rl.build_reward_fns(trainer_config, make_reward_fn=lambda fn: fn) + self.assertEqual( + reward_fns, + [ + train_rl.utils_rl.match_format_exactly, + train_rl.utils_rl.match_format_approximately, + train_rl.utils_rl.check_numbers, + ], + ) + + @pytest.mark.cpu_only + def test_build_reward_fns_custom_replaces_builtins(self): + """When both knobs are set, the stack is the user-provided functions only.""" + trainer_config = SimpleNamespace( + reward_functions_path="/tmp/my_rewards.py", + reward_functions="reward_a, reward_b", + ) + loaded = {"reward_a": object(), "reward_b": object()} + with mock.patch.object( + train_rl.utils_rl, "load_custom_callable", side_effect=lambda path, name: loaded[name] + ) as mock_load: + reward_fns = train_rl.build_reward_fns(trainer_config, make_reward_fn=lambda fn: fn) + self.assertEqual(reward_fns, [loaded["reward_a"], loaded["reward_b"]]) + mock_load.assert_has_calls( + [ + mock.call("/tmp/my_rewards.py", "reward_a"), + mock.call("/tmp/my_rewards.py", "reward_b"), + ] + ) + + @pytest.mark.cpu_only + def test_build_reward_fns_partial_config_falls_back(self): + """If only one of the two knobs is set, the built-in stack is used.""" + trainer_config = SimpleNamespace(reward_functions_path="/tmp/my_rewards.py", reward_functions="") + reward_fns = train_rl.build_reward_fns(trainer_config, make_reward_fn=lambda fn: fn) + self.assertEqual(len(reward_fns), 3) + if __name__ == "__main__": unittest.main()