mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
fix reward taging bug
This commit is contained in:
@@ -6,13 +6,6 @@ import ray
|
||||
import torch
|
||||
from coati.distributed.launch import launch_distributed
|
||||
|
||||
DEFAULT_RESPONSE_FORMAT_TAGS = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||
@@ -248,18 +241,18 @@ if __name__ == "__main__":
|
||||
num_generations=args.num_generations,
|
||||
train_model_config=train_model_config,
|
||||
grpo_config=grpo_config,
|
||||
plugin_config={
|
||||
"zero_stage": 2,
|
||||
}, # for zero
|
||||
# plugin_config={
|
||||
# "tp_size": 4,
|
||||
# "pp_size": 2,
|
||||
# "microbatch_size": max(
|
||||
# 1, args.train_microbatch_size // 2
|
||||
# ), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
# "zero_stage": 1,
|
||||
# "max_norm": 1.0,
|
||||
# }, # for pp, tp
|
||||
# "zero_stage": 2,
|
||||
# }, # for zero
|
||||
plugin_config={
|
||||
"tp_size": 4,
|
||||
"pp_size": 2,
|
||||
"microbatch_size": max(
|
||||
1, args.train_microbatch_size // 2
|
||||
), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
"zero_stage": 1,
|
||||
"max_norm": 1.0,
|
||||
}, # for pp, tp
|
||||
inference_backend=args.backend,
|
||||
master_port=args.torch_ddp_master_port,
|
||||
core_algo=args.algo,
|
||||
@@ -272,9 +265,5 @@ if __name__ == "__main__":
|
||||
},
|
||||
eval_interval=args.eval_interval,
|
||||
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
||||
response_format_tags=(
|
||||
json.loads(args.response_format_tags)
|
||||
if args.response_format_tags is not None
|
||||
else DEFAULT_RESPONSE_FORMAT_TAGS
|
||||
),
|
||||
response_format_tags=(json.loads(args.response_format_tags) if args.response_format_tags is not None else {}),
|
||||
)
|
||||
|
Reference in New Issue
Block a user