This commit is contained in:
YeAnbang
2025-03-19 17:07:20 +08:00
parent 7795d4c50d
commit 7ee4452f8c
5 changed files with 172 additions and 24 deletions

View File

@@ -1,15 +1,13 @@
import copy
from typing import Any, Dict, Optional
import ray
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
from .producer import SimpleProducer
ALGO_MAP = {
"Simple": SimpleConsumer,
"GRPO": GRPOConsumer,
}
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
def get_jsonl_size_fast(path: str) -> int:
@@ -80,6 +78,12 @@ def launch_distributed(
backend=inference_backend,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(
backend=inference_backend,
)
)
for i in range(num_consumer_procs):
consumer = core_consumer.options(num_gpus=1).remote(
num_producers=num_producers,
@@ -94,6 +98,8 @@ def launch_distributed(
model_config=train_model_config,
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
generate_config=generate_config_consumer,
filter_range=[0.05, 9.0],
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])