Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/maxtext/configs/post_train/rl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ eval_dataset_name: 'openai/gsm8k'
# tmvp_config, x) -> dict with keys {prompts, question, answer}. When empty
# (default), the built-in utils_rl.process_data is used.
dataset_processor_path: ''
# Optional: path to a user-provided Python file with custom reward functions,
# AND a comma-separated list of function names to import from that file.
# Each function signature: fn(prompts, completions, tmvp_config, **kwargs) -> list[float].
# When both are set, the built-in [match_format_exactly, match_format_approximately,
# check_numbers] reward list is REPLACED entirely.
reward_functions_path: ''
reward_functions: ''
train_split: 'train'
eval_split: 'test'
hf_name: 'main' # subset of Hugging Face dataset
Expand Down
14 changes: 14 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2056,6 +2056,20 @@ class RLDataset(BaseModel):
"When set, replaces the built-in dataset processor for custom datasets."
),
)
reward_functions_path: str = Field(
"",
description=(
"Optional path to a user Python file containing custom reward functions. "
"Used with `reward_functions` to fully replace the built-in reward stack."
),
)
reward_functions: str = Field(
"",
description=(
"Comma-separated names of reward functions to import from `reward_functions_path`. "
"Each function signature: (prompts, completions, tmvp_config, **kwargs) -> list[float]."
),
)


class RLEvaluation(BaseModel):
Expand Down
37 changes: 31 additions & 6 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from __future__ import annotations
import contextlib
from functools import wraps
from typing import Any, Optional, Sequence
from typing import Any, Callable, Optional, Sequence

import datasets
import grain
Expand Down Expand Up @@ -363,6 +363,31 @@ def _use_raw_prompt(x):
return train_dataset, test_dataset


def build_reward_fns(trainer_config: Any, make_reward_fn: Callable) -> list:
"""Build the reward-function stack for the RL trainer.

`reward_functions_path` is a filesystem path to a Python file and
`reward_functions` is a comma-separated list of function names to import from
it. When both are set, the built-in stack is REPLACED entirely by the
user-provided callables (so users have full control over their reward stack).
Otherwise the default
(`match_format_exactly`, `match_format_approximately`, `check_numbers`) stack
is used. Every reward function is wrapped via `make_reward_fn`.
"""
custom_rewards_path = getattr(trainer_config, "reward_functions_path", "") or ""
custom_rewards_names = getattr(trainer_config, "reward_functions", "") or ""
if custom_rewards_path and custom_rewards_names:
names = [n.strip() for n in custom_rewards_names.split(",") if n.strip()]
reward_fns = [make_reward_fn(utils_rl.load_custom_callable(custom_rewards_path, n)) for n in names]
max_logging.log(f"reward_fns: using {len(reward_fns)} custom reward function(s) {names} from {custom_rewards_path}")
return reward_fns
return [
make_reward_fn(utils_rl.match_format_exactly),
make_reward_fn(utils_rl.match_format_approximately),
make_reward_fn(utils_rl.check_numbers),
]


def create_rl_components(
trainer_config,
sampler_config,
Expand Down Expand Up @@ -526,11 +551,11 @@ def _reward_fn(**kwargs):

return _reward_fn

reward_fns = [ # type: ignore
make_reward_fn(utils_rl.match_format_exactly),
make_reward_fn(utils_rl.match_format_approximately),
make_reward_fn(utils_rl.check_numbers),
]
# Optional user-provided reward functions: when `reward_functions_path` and
# `reward_functions` are both set the built-in stack is replaced entirely by
# the user-provided callables. Each function must accept `prompts`,
# `completions`, `tmvp_config`, and `**kwargs` and return a list of floats.
reward_fns = build_reward_fns(trainer_config, make_reward_fn)

# Create RL trainer
max_logging.log("Setting up RL trainer...")
Expand Down
41 changes: 41 additions & 0 deletions tests/post_training/unit/train_rl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,47 @@ def test_rl_train_invalid_optimizer_memory_host_offload(self, mock_setup):
with self.assertRaisesRegex(ValueError, "optimizer_memory_host_offload=True is not supported"):
train_rl._rl_train_impl([], {}) # pylint: disable=protected-access

@pytest.mark.cpu_only
def test_build_reward_fns_defaults_when_no_custom(self):
"""With neither knob set, the built-in 3-fn stack is returned."""
trainer_config = SimpleNamespace(reward_functions_path="", reward_functions="")
reward_fns = train_rl.build_reward_fns(trainer_config, make_reward_fn=lambda fn: fn)
self.assertEqual(
reward_fns,
[
train_rl.utils_rl.match_format_exactly,
train_rl.utils_rl.match_format_approximately,
train_rl.utils_rl.check_numbers,
],
)

@pytest.mark.cpu_only
def test_build_reward_fns_custom_replaces_builtins(self):
"""When both knobs are set, the stack is the user-provided functions only."""
trainer_config = SimpleNamespace(
reward_functions_path="/tmp/my_rewards.py",
reward_functions="reward_a, reward_b",
)
loaded = {"reward_a": object(), "reward_b": object()}
with mock.patch.object(
train_rl.utils_rl, "load_custom_callable", side_effect=lambda path, name: loaded[name]
) as mock_load:
reward_fns = train_rl.build_reward_fns(trainer_config, make_reward_fn=lambda fn: fn)
self.assertEqual(reward_fns, [loaded["reward_a"], loaded["reward_b"]])
mock_load.assert_has_calls(
[
mock.call("/tmp/my_rewards.py", "reward_a"),
mock.call("/tmp/my_rewards.py", "reward_b"),
]
)

@pytest.mark.cpu_only
def test_build_reward_fns_partial_config_falls_back(self):
"""If only one of the two knobs is set, the built-in stack is used."""
trainer_config = SimpleNamespace(reward_functions_path="/tmp/my_rewards.py", reward_functions="")
reward_fns = train_rl.build_reward_fns(trainer_config, make_reward_fn=lambda fn: fn)
self.assertEqual(len(reward_fns), 3)


if __name__ == "__main__":
unittest.main()
Loading