add algo selection

This commit is contained in:
Tong Li 2025-03-06 14:29:22 +08:00
parent 812f4b7750
commit 0f566cc2d4
2 changed files with 16 additions and 1 deletions

View File

@ -2,9 +2,15 @@ from typing import Any, Dict, Optional
import ray
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
ALGO_MAP = {
"Simple": SimpleConsumer,
"GRPO": GRPOConsumer,
}
def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
@ -40,7 +46,14 @@ def launch_distributed(
inference_backend: str = "transformers",
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
else:
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@ -68,7 +81,7 @@ def launch_distributed(
)
procs.append(producer)
for i in range(num_consumer_procs):
consumer = GRPOConsumer.options(num_gpus=1).remote(
consumer = core_consumer.options(num_gpus=1).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,

View File

@ -15,6 +15,7 @@ if __name__ == "__main__":
parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
parser.add_argument("-b", "--backend", type=str, default="transformers")
parser.add_argument("-a", "--algo", type=str, default="GPRO", choices=["Simple, GPRO"])
args = parser.parse_args()
ray.init(address="local", namespace="ray-example")
@ -95,4 +96,5 @@ if __name__ == "__main__":
inference_backend=args.backend,
master_addr="localhost",
master_port=29504,
core_algo=args.algo
)