add algo selection

This commit is contained in:
Tong Li
2025-03-06 14:29:22 +08:00
parent 812f4b7750
commit 0f566cc2d4
2 changed files with 16 additions and 1 deletions

View File

@@ -2,9 +2,15 @@ from typing import Any, Dict, Optional
import ray
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
ALGO_MAP = {
"Simple": SimpleConsumer,
"GRPO": GRPOConsumer,
}
def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
@@ -40,7 +46,14 @@ def launch_distributed(
inference_backend: str = "transformers",
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
else:
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@@ -68,7 +81,7 @@ def launch_distributed(
)
procs.append(producer)
for i in range(num_consumer_procs):
consumer = GRPOConsumer.options(num_gpus=1).remote(
consumer = core_consumer.options(num_gpus=1).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,