fix scheduling for multi-node training

This commit is contained in:
YeAnbang
2025-05-02 19:45:07 +08:00
parent d06042b434
commit 7d658402da
7 changed files with 124 additions and 38 deletions


@@ -6,6 +6,13 @@ import ray
import torch
from coati.distributed.launch import launch_distributed
DEFAULT_RESPONSE_FORMAT_TAGS = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -65,10 +72,22 @@ if __name__ == "__main__":
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
"--master_address",
type=str,
default=None,
help="Ray master address for multi-node distributed training, Optional",
)
parser.add_argument(
"--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
"--torch_ddp_master_address",
type=str,
default=None,
help="Torch DDP master address for multi-node distributed training, Optional",
)
parser.add_argument(
"--torch_ddp_master_port",
type=int,
default=29505,
help="Torch DDP master port for multi-node distributed training, Optional",
)
# Sampling parameters
@@ -105,6 +124,9 @@ if __name__ == "__main__":
help="Reward type for GRPO.",
)
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
parser.add_argument(
"-rft", "--reponse_format-tags", type=str, default=None, help="Optional json string of the response format tag"
)
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -236,17 +258,17 @@ if __name__ == "__main__":
"zero_stage": 2,
}, # for zero
# plugin_config={
# "tp_size": 1,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 0,
# "max_norm": 1.0,
# "tp_size": 4,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 1,
# "max_norm": 1.0,
# }, # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,
master_addr=args.torch_ddp_master_address,
master_port=args.torch_ddp_master_port,
core_algo=args.algo,
project_name=args.project,
save_interval=args.save_interval,
@@ -257,4 +279,9 @@ if __name__ == "__main__":
},
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
response_format_tags=(
json.loads(args.response_format_tags)
if args.response_format_tags is not None
else DEFAULT_RESPONSE_FORMAT_TAGS
),
)
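
Note: this commit separates the Ray cluster address (--master_address) from the rendezvous endpoint used by the trainers' torch DDP process group (--torch_ddp_master_address / --torch_ddp_master_port). A minimal sketch of how such an endpoint is typically consumed by torch.distributed is given below; the address, port, rank, and world size are placeholder values, and the actual wiring inside launch_distributed is assumed rather than taken from this diff.

import os
import torch.distributed as dist

# Hypothetical values standing in for args.torch_ddp_master_address and
# args.torch_ddp_master_port; the real plumbing lives inside launch_distributed.
os.environ["MASTER_ADDR"] = "10.0.0.1"
os.environ["MASTER_PORT"] = "29505"

# With the env:// init method, torch reads MASTER_ADDR / MASTER_PORT to
# rendezvous all ranks of the DDP process group. gloo is used here so the
# sketch runs on CPU; GPU training would typically use nccl.
dist.init_process_group(backend="gloo", init_method="env://", rank=0, world_size=1)
dist.destroy_process_group()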
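
For reference, the new --response_format_tags option expects a single JSON string whose structure matches DEFAULT_RESPONSE_FORMAT_TAGS. The sketch below (not part of the commit) shows how such a string round-trips through json.dumps / json.loads; how launch_distributed consumes the resulting dict is assumed, not shown in this diff.

import json

# Tag schema mirrors DEFAULT_RESPONSE_FORMAT_TAGS from the diff above.
custom_tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}

# Serialize to a single JSON string suitable for passing on the command line
# as: --response_format_tags '<json>'
tags_arg = json.dumps(custom_tags)

# The script recovers the dict exactly as in the diff:
# json.loads(args.response_format_tags) if it is not None,
# else DEFAULT_RESPONSE_FORMAT_TAGS.
parsed = json.loads(tags_arg)
assert parsed == custom_tags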