Mirror of https://github.com/hpcaitech/ColossalAI.git
Commit 56e4e74140 ("boxed version")
Parent: b823c6eec7
@@ -375,7 +375,7 @@ def apply_chat_template_and_mask(
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
-        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
+        msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True, add_generation_prompt=True)
         # remove unexpected bos token
         if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
            msg_tokens = msg_tokens[1:]
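
This hunk stops passing system_element into the per-message apply_chat_template call, so the system prompt is no longer re-templated into every turn. A minimal sketch of the per-message pattern, assuming a Hugging Face tokenizer (the model name is illustrative, not taken from the repo):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model
msg = {"role": "user", "content": "What is 2 + 2?"}

# Template a single message on its own, as the updated line does.
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True, add_generation_prompt=True)

# Some chat templates still emit a BOS token on every call; the surrounding
# loop strips it for each message after the first.
if msg_tokens and msg_tokens[0] == tokenizer.bos_token_id:
    msg_tokens = msg_tokens[1:]
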
@@ -18,12 +18,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer, processed_str = extract_solution(decoded_final_answer)
 
-    format_valid = validate_response_structure(processed_str, kwargs["tags"])
+    # format_valid = validate_response_structure(processed_str, kwargs["tags"])
 
-    # Check format accuracy
-    if format_valid:
-        format_reward += format_score
-        reward += format_score
+    # # Check format accuracy
+    # if format_valid:
+    # format_reward += format_score
+    # reward += format_score
 
     # Check answer accuracy
     if (
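
With the structure check commented out, math_reward_fn no longer grants format_score; the reward now comes only from the answer-accuracy branch that follows the hunk. A minimal sketch of the resulting behavior, using hypothetical stand-ins (final_answer, gt_answer, and acc_score are illustrative names, not the exact variables in the file):

def sketch_reward(final_answer, gt_answer, acc_score=1.0):
    # Format validation and format_score are disabled by this commit,
    # so only answer correctness contributes to the reward.
    reward = 0.0
    if final_answer is not None and final_answer.strip() == gt_answer.strip():
        reward += acc_score
    return reward
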
@@ -66,7 +66,8 @@ def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
     """
 
     # Extract final answer using XML-style tags
-    answer_pattern = r"<answer>(.*?)</answer>"
+    # answer_pattern = r"<answer>(.*?)</answer>"
+    answer_pattern = r"boxed{(.*?)}"
     matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
 
     if not matches:
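
The replacement pattern extracts \boxed{...} answers instead of <answer>...</answer> tags, which matches the commit title. A quick, runnable check of the new regex; note that the lazy (.*?) stops at the first closing brace, so a nested expression such as \boxed{\frac{1}{2}} would only capture "\frac{1":

import re

answer_pattern = r"boxed{(.*?)}"
solution_str = r"The sum is 4, so the final answer is \boxed{4}."

matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
print(matches[-1].group(1) if matches else None)  # -> 4
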
@@ -44,7 +44,7 @@ if __name__ == "__main__":
         "-tmbs",
         "--train-microbatch-size",
         type=int,
-        default=2,
+        default=1,
         help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
     )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
@@ -84,7 +84,7 @@ if __name__ == "__main__":
         inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
         generate_config.update(
             dict(
-                max_tokens=2048,
+                max_tokens=4096,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
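
These keys line up with vLLM's SamplingParams, so assuming generate_config is forwarded to vLLM (the diff does not show the call site), the updated settings correspond to:

from vllm import SamplingParams

params = SamplingParams(
    max_tokens=4096,                  # raised from 2048: room for a long derivation before \boxed{...}
    ignore_eos=True,                  # do not stop at the model's EOS token
    include_stop_str_in_output=True,  # keep the stop string in the returned text
    stop=["</answer>"],               # still cuts generation if the old tag is emitted
)
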
@@ -107,7 +107,7 @@ if __name__ == "__main__":
         num_producers=args.num_inferencer,
         num_proc_per_producer=1,
         num_consumer_procs=args.num_trainers,
-        num_episodes=1,
+        num_episodes=2,
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,