docs: add reward function customization guide #997
NJX-njx wants to merge 2 commits into inclusionAI:main
Conversation
Add a comprehensive reward customization guide covering:
- Reward function signature and conventions
- Rule-based reward functions with examples
- Built-in reward functions reference table
- Integration with training scripts
- Generative reward model (LLM-as-Judge) pattern
- AsyncRewardWrapper usage for slow reward functions
- How to register new built-in reward functions

Also:
- Add both English (docs/en/customization/reward.md) and Chinese (docs/zh/customization/reward.md) versions
- Register new pages in both _toc.yml files

Refs: inclusionAI#124, inclusionAI#162
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the AReaL documentation by introducing a detailed guide on customizing reward functions. It provides users with clear instructions and examples for implementing various reward mechanisms, from simple rule-based functions to LLM-as-Judge models, and explains how to integrate them into the training workflow. This update addresses long-standing documentation gaps and improves usability for developers working with reinforcement learning in AReaL.

Highlights
Pull request overview
This PR adds comprehensive documentation for reward function customization in AReaL, addressing documentation gaps from issues #124, #162, and the roadmap item from #907. The guide covers the reward function interface, rule-based rewards, built-in rewards, training integration, LLM-as-Judge generative reward models, async handling, and contributing new built-in rewards.
Changes:
- Added English and Chinese reward customization guides covering the full reward function lifecycle (signature, examples, integration, LLM-as-Judge, async wrapper, registration)
- Added the new reward page to both English and Chinese table of contents files, placed between dataset and agent customization
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| docs/en/customization/reward.md | New English guide for reward function customization |
| docs/zh/customization/reward.md | New Chinese guide (translation of English guide) |
| docs/en/_toc.yml | Added customization/reward entry to Customization section |
| docs/zh/_toc.yml | Added customization/reward entry to 定制指南 section |
| Function | Module path | Dataset |
|---|---|---|
| `gsm8k_reward_fn` | `areal.reward.gsm8k.gsm8k_reward_fn` | GSM8K math |
| `geometry3k_reward_fn` | `areal.reward.geometry3k.geometry3k_reward_fn` | Geometry3K |
| `clevr_count_70k_reward_fn` | `areal.reward.clevr_count_70k.clevr_count_70k_reward_fn` | CLEVR Count |
The built-in reward functions table lists `gsm8k_reward_fn` as a built-in reward function, but it is not registered in `VALID_REWARD_FN` or handled by `get_custom_reward_fn` in `areal/reward/__init__.py` (line 8: `VALID_REWARD_FN = ["clevr_count_70k", "geometry3k"]`). Unlike `geometry3k_reward_fn` and `clevr_count_70k_reward_fn`, `gsm8k_reward_fn` cannot be auto-selected by dataset name. The table should either note this distinction, or the function should also be registered in `areal/reward/__init__.py`.
Note: `geometry3k_reward_fn` and `clevr_count_70k_reward_fn` are registered in `areal.reward.__init__.VALID_REWARD_FN` and can be auto-selected based on the dataset name. `gsm8k_reward_fn` is shipped as a utility but is **not** auto-selected; use its module path string (`"areal.reward.gsm8k.gsm8k_reward_fn"`) when configuring `reward_fn`.
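The mechanics behind passing a module path string are just Python's dynamic import. A minimal sketch of how such a dotted string can be resolved to a callable, using a stdlib function as a stand-in (this illustrates the pattern only; AReaL's actual resolution logic in `areal.reward` is not reproduced here):

```python
import importlib


def resolve_by_path(dotted_path: str):
    """Resolve a 'pkg.module.attr' string to the object it names."""
    module_path, _, attr_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# Stand-in for resolving "areal.reward.gsm8k.gsm8k_reward_fn":
sqrt = resolve_by_path("math.sqrt")
print(sqrt(9.0))  # → 3.0
```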
The same table and note appear, translated, in the Chinese guide (docs/zh/customization/reward.md); Copilot left the identical comment there.
your reward function for async execution with timeout handling:

```python
from areal.api import AsyncRewardWrapper
```
The import `from areal.api import AsyncRewardWrapper` will not work because there is no `__init__.py` in `areal/api/`. All existing usages in the codebase use `from areal.api.reward_api import AsyncRewardWrapper` (see e.g. areal/workflow/rlvr.py:12, areal/workflow/multi_turn.py:12). The import path should be corrected to `from areal.api.reward_api import AsyncRewardWrapper`.
```diff
- from areal.api import AsyncRewardWrapper
+ from areal.api.reward_api import AsyncRewardWrapper
```
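For readers unfamiliar with what such a wrapper does: the general pattern of pushing a blocking reward function onto a worker thread with a timeout can be sketched with the standard library alone. This is an illustration of the pattern, not AReaL's actual `AsyncRewardWrapper` implementation; all names below are hypothetical:

```python
import asyncio


def slow_reward_fn(prompt: str, completion: str) -> float:
    """Stand-in for an expensive, blocking reward computation."""
    return 1.0 if "42" in completion else 0.0


async def reward_with_timeout(prompt: str, completion: str, timeout: float = 5.0) -> float:
    """Run the blocking reward on a worker thread; fall back to 0.0 on timeout."""
    try:
        return await asyncio.wait_for(
            asyncio.to_thread(slow_reward_fn, prompt, completion),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        return 0.0


print(asyncio.run(reward_with_timeout("q", "the answer is 42")))  # → 1.0
```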
The Chinese guide contains the same snippet ("... your reward function for async execution with timeout handling:") with the same incorrect import; Copilot left the identical comment and suggested the same fix there.
Code Review
This pull request adds comprehensive documentation for reward function customization, which is a great addition. The guide is well-structured and covers various aspects from basic rule-based functions to advanced LLM-as-judge models.
My review focuses on the code examples provided in the new documentation files. I've identified a few areas where the examples can be improved to better reflect best practices in terms of performance, robustness, and code style. Specifically, I've suggested changes to:
- Move imports to the top of the file scope.
- Instantiate API clients once and reuse them to avoid inefficiency.
- Add error handling for external API calls to make the code more robust.
These changes will ensure that users who copy the example code will start with a more solid foundation. The same suggestions apply to both the English and Chinese versions of the documentation to maintain consistency.
```python
import re


def llm_judge_reward_fn(
    prompt, completions, prompt_ids, completion_ids, **kwargs
) -> float:
    """Use an external LLM API to judge response quality."""
    import openai

    judge_prompt = f"""Rate the following response on a scale of 0 to 10.
Only output the numeric score.

Question: {prompt}
Response: {completions}

Score:"""

    client = openai.OpenAI(
        base_url="http://localhost:8000/v1",  # local vLLM/SGLang server
        api_key="unused",
    )
    response = client.chat.completions.create(
        model="Qwen/Qwen3-8B",
        messages=[{"role": "user", "content": judge_prompt}],
        temperature=0.0,
        max_tokens=16,
    )
    score_text = response.choices[0].message.content.strip()

    # Extract numeric score
    match = re.search(r"(\d+(?:\.\d+)?)", score_text)
    if match:
        score = float(match.group(1))
        return min(score / 10.0, 1.0)  # Normalize to [0, 1]
    return 0.0
```
This example function can be improved in terms of efficiency, robustness, and style.

- Efficiency: a new `openai.OpenAI` client is created on every function call, which is inefficient. The client should be created once and reused.
- Style: the `import openai` statement is inside the function. Imports should be at the top of the file.
- Robustness: the API call is not wrapped in a `try...except` block, so any network or API error will crash the reward calculation.
Here is a revised version that addresses these points:
```python
import re

import openai

# Create the client once at the module level for efficiency.
client = openai.OpenAI(
    base_url="http://localhost:8000/v1",  # local vLLM/SGLang server
    api_key="unused",
)


def llm_judge_reward_fn(
    prompt, completions, prompt_ids, completion_ids, **kwargs
) -> float:
    """Use an external LLM API to judge response quality."""
    judge_prompt = f"""Rate the following response on a scale of 0 to 10.
Only output the numeric score.

Question: {prompt}
Response: {completions}

Score:"""
    try:
        response = client.chat.completions.create(
            model="Qwen/Qwen3-8B",
            messages=[{"role": "user", "content": judge_prompt}],
            temperature=0.0,
            max_tokens=16,
        )
        score_text = response.choices[0].message.content.strip()
        # Extract numeric score
        match = re.search(r"(\d+(?:\.\d+)?)", score_text)
        if match:
            score = float(match.group(1))
            return min(score / 10.0, 1.0)  # Normalize to [0, 1]
    except Exception as e:
        # In a real application, you should log this error.
        print(f"LLM-as-judge API call failed: {e}")
    return 0.0
```
The Chinese guide contains the same `llm_judge_reward_fn` example with the comments and judge prompt in Chinese; Gemini Code Assist left the identical review comment and an equivalently revised version there.
```python
import re


def composite_reward_fn(
    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
) -> float:
    """Reward that combines format compliance and answer accuracy."""
    format_score = 0.0
    accuracy_score = 0.0

    # Check if response follows the expected format (e.g., uses \boxed{})
    if re.search(r"\\boxed\{.+\}", completions):
        format_score = 0.2

    # Check answer accuracy
    from areal.reward import get_math_verify_worker
    try:
        accuracy_score = get_math_verify_worker().verify(
            str(completions), str(answer)
        ) * 0.8
    except Exception:
        pass

    return format_score + accuracy_score
```
This code snippet can be improved by moving the import statement to the top of the file, which is the standard Python convention (PEP 8). This improves readability and avoids re-executing the import statement on every function call.
```python
import re

from areal.reward import get_math_verify_worker


def composite_reward_fn(
    prompt, completions, prompt_ids, completion_ids, answer, **kwargs
) -> float:
    """Reward that combines format compliance and answer accuracy."""
    format_score = 0.0
    accuracy_score = 0.0

    # Check if response follows the expected format (e.g., uses \boxed{})
    if re.search(r"\\boxed\{.+\}", completions):
        format_score = 0.2

    # Check answer accuracy
    try:
        accuracy_score = get_math_verify_worker().verify(
            str(completions), str(answer)
        ) * 0.8
    except Exception:
        pass

    return format_score + accuracy_score
```
The Chinese guide contains the same `composite_reward_fn` example with comments in Chinese; Gemini Code Assist left the identical comment and the equivalent revised version there.
garrett4wade
left a comment
The new doc is indeed a plus, but it does not fully resolve the issues.

- For LLM-as-a-judge rewards, writing the reward as a normal synchronous function limits concurrency. That's why we usually wrap it with `AsyncRewardWrapper` and put the computation on a dedicated process pool for the expensive math reward. However, LLM calls are natively async, so we should use native async rewards instead of wrapping them in `AsyncRewardWrapper`.
- I think one of the major concerns in the issues is how we can add critic-like rewards, where we use a fine-tuned LLM that outputs a scalar value as the reward. That requires heavy GPU computation and cannot be incorporated into the rollout workflow. Instead, reward computation should be added to the training loop, and the reward model should be offloaded immediately after use.
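The first point, that natively async LLM calls should not be funneled through a synchronous function plus a wrapper, can be illustrated with a small sketch. The stub judge below stands in for a real async client call (e.g. `openai.AsyncOpenAI`); the point is that `asyncio.gather` keeps many judge requests in flight at once without a process pool. All names here are illustrative, not AReaL's API:

```python
import asyncio


async def judge_once(prompt: str, completion: str) -> float:
    """Stand-in for an async LLM-judge call (e.g. via openai.AsyncOpenAI)."""
    await asyncio.sleep(0.01)  # simulates network latency of the judge request
    return 1.0 if "42" in completion else 0.0


async def batch_rewards(prompts: list[str], completions: list[str]) -> list[float]:
    """Score a whole batch concurrently: wall time ~ one call, not N sequential calls."""
    tasks = [judge_once(p, c) for p, c in zip(prompts, completions)]
    return await asyncio.gather(*tasks)


scores = asyncio.run(batch_rewards(["q1", "q2"], ["it is 42", "no idea"]))
print(scores)  # → [1.0, 0.0]
```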
Summary
Add comprehensive reward function customization documentation, addressing long-standing documentation gaps from issues #124 and #162. This also covers the roadmap item "Example of using a generative or critic-like reward model" from #907.
Changes
New files
Modified files
Guide Contents
Refs: #124, #162, #907