mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
add algo selection
This commit is contained in:
@@ -2,9 +2,15 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import ray
|
||||
|
||||
from .consumer import SimpleConsumer
|
||||
from .grpo_consumer import GRPOConsumer
|
||||
from .producer import SimpleProducer
|
||||
|
||||
ALGO_MAP = {
|
||||
"Simple": SimpleConsumer,
|
||||
"GRPO": GRPOConsumer,
|
||||
}
|
||||
|
||||
|
||||
def get_jsonl_size_fast(path: str) -> int:
|
||||
with open(path) as f:
|
||||
@@ -40,7 +46,14 @@ def launch_distributed(
|
||||
inference_backend: str = "transformers",
|
||||
master_addr: str = "localhost",
|
||||
master_port: int = 29500,
|
||||
core_algo: str = "GRPO",
|
||||
):
|
||||
|
||||
if core_algo not in ALGO_MAP:
|
||||
raise NotImplementedError(f"{core_algo} is not supported yet.")
|
||||
else:
|
||||
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
|
||||
|
||||
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
|
||||
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
||||
|
||||
@@ -68,7 +81,7 @@ def launch_distributed(
|
||||
)
|
||||
procs.append(producer)
|
||||
for i in range(num_consumer_procs):
|
||||
consumer = GRPOConsumer.options(num_gpus=1).remote(
|
||||
consumer = core_consumer.options(num_gpus=1).remote(
|
||||
num_producers=num_producers,
|
||||
num_episodes=num_episodes,
|
||||
rank=i,
|
||||
|
Reference in New Issue
Block a user