boxed version

This commit is contained in:
Tong Li 2025-04-23 17:20:09 +08:00
parent b823c6eec7
commit 56e4e74140
4 changed files with 11 additions and 10 deletions

View File

@ -375,7 +375,7 @@ def apply_chat_template_and_mask(
tokens = [] tokens = []
assistant_mask = [] assistant_mask = []
for i, msg in enumerate(chat): for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True) msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True, add_generation_prompt=True)
# remove unexpected bos token # remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id: if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:] msg_tokens = msg_tokens[1:]

View File

@ -18,12 +18,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer) final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"]) # format_valid = validate_response_structure(processed_str, kwargs["tags"])
# Check format accuracy # # Check format accuracy
if format_valid: # if format_valid:
format_reward += format_score # format_reward += format_score
reward += format_score # reward += format_score
# Check answer accuracy # Check answer accuracy
if ( if (

View File

@ -66,7 +66,8 @@ def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
""" """
# Extract final answer using XML-style tags # Extract final answer using XML-style tags
answer_pattern = r"<answer>(.*?)</answer>" # answer_pattern = r"<answer>(.*?)</answer>"
answer_pattern = r"boxed{(.*?)}"
matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL)) matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
if not matches: if not matches:

View File

@ -44,7 +44,7 @@ if __name__ == "__main__":
"-tmbs", "-tmbs",
"--train-microbatch-size", "--train-microbatch-size",
type=int, type=int,
default=2, default=1,
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.", help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
) )
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
@ -84,7 +84,7 @@ if __name__ == "__main__":
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True)) inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
generate_config.update( generate_config.update(
dict( dict(
max_tokens=2048, max_tokens=4096,
ignore_eos=True, ignore_eos=True,
include_stop_str_in_output=True, include_stop_str_in_output=True,
stop=["</answer>"], stop=["</answer>"],
@ -107,7 +107,7 @@ if __name__ == "__main__":
num_producers=args.num_inferencer, num_producers=args.num_inferencer,
num_proc_per_producer=1, num_proc_per_producer=1,
num_consumer_procs=args.num_trainers, num_consumer_procs=args.num_trainers,
num_episodes=1, num_episodes=2,
inference_batch_size=args.inference_batch_size, inference_batch_size=args.inference_batch_size,
inference_microbatch_size=args.inference_microbatch_size, inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size, train_batch_size=args.train_batch_size,