From 0f566cc2d49657fec09c34aa50389f951f55103b Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 14:29:22 +0800 Subject: [PATCH] add algo selection --- .../ColossalChat/coati/distributed/launch.py | 15 ++++++++++++++- applications/ColossalChat/rl_example.py | 2 ++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 5244cc7d9..8581ff586 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -2,9 +2,15 @@ from typing import Any, Dict, Optional import ray +from .consumer import SimpleConsumer from .grpo_consumer import GRPOConsumer from .producer import SimpleProducer +ALGO_MAP = { + "Simple": SimpleConsumer, + "GRPO": GRPOConsumer, +} + def get_jsonl_size_fast(path: str) -> int: with open(path) as f: @@ -40,7 +46,14 @@ def launch_distributed( inference_backend: str = "transformers", master_addr: str = "localhost", master_port: int = 29500, + core_algo: str = "GRPO", ): + + if core_algo not in ALGO_MAP: + raise NotImplementedError(f"{core_algo} is not supported yet.") + else: + core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer) + train_dp_size = get_dp_size_fast(num_producers, plugin_config) assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0 @@ -68,7 +81,7 @@ def launch_distributed( ) procs.append(producer) for i in range(num_consumer_procs): - consumer = GRPOConsumer.options(num_gpus=1).remote( + consumer = core_consumer.options(num_gpus=1).remote( num_producers=num_producers, num_episodes=num_episodes, rank=i, diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 57cab4164..40231582d 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -15,6 +15,7 @@ if __name__ == "__main__": parser.add_argument("-tbs", 
"--train-batch-size", type=int, default=16) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) parser.add_argument("-b", "--backend", type=str, default="transformers") + parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO"]) args = parser.parse_args() ray.init(address="local", namespace="ray-example") @@ -95,4 +96,5 @@ if __name__ == "__main__": inference_backend=args.backend, master_addr="localhost", master_port=29504, + core_algo=args.algo )