add custom agentic producer

This commit is contained in:
YeAnbang
2025-09-16 16:23:46 +08:00
parent 62f82a75ae
commit edcef9edaf
12 changed files with 482 additions and 378 deletions

View File

@@ -4,8 +4,9 @@ import uuid
from typing import Any, Dict, Optional
import ray
from coati.distributed.agent.langgraph_math_agentic import LangGraphMathAgenticProducer
from coati.distributed.agent.qwen_math_agentic import QwenMathAgenticProducer
from coati.distributed.agent.agentic_producer import AgenticProducer
from coati.distributed.agent.qwen_math_agentic_producer import QwenMathAgenticProducer
from coati.distributed.agent.tool_worker import ToolWorker
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
@@ -21,7 +22,7 @@ ALGO_MAP = {
Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer}
AGENTIC_PRODUCER_MAP = {
"QwenMathAgent": QwenMathAgenticProducer,
"LangGraphMathAgent": LangGraphMathAgenticProducer,
"Agentic": AgenticProducer,
} # supported agentic producers
@@ -165,21 +166,16 @@ def launch_distributed(
)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
"""
# test async generate
import torch
import asyncio
import time
async def test():
res_ref = producer_procs[0].generate.remote(torch.ones((2, 10), dtype=torch.int), torch.ones((2, 10), dtype=torch.int))
res = await res_ref
return res
res = asyncio.run(test())
print(res)
time.sleep(1000)
"""
if enable_agentic:
from coati.distributed.agent.math_tools import repl_tool
# setup tool workers
tool_workers = []
if agentic_config["agentic_producer"] == "Agentic":
# 10 tool workers can handle 50 queries simultaneously
# note that imported repl_tool will be serialized and deserialized in each tool worker, therefore all workers can run parallely
tool_workers = [ToolWorker.remote([repl_tool]) for _ in range(agentic_config.get("num_tool_workers", 10))]
# when agentic is enabled, we use core_producer as inference engine and
# AgenticProducer as the real producer
_producer_procs = producer_procs
@@ -194,7 +190,7 @@ def launch_distributed(
producer_procs = [
agentic_producer_cls.options(num_cpus=1).remote(
producer_idx=producer_idx,
num_producers=num_producers * train_batch_size,
num_producers=num_producers * inference_batch_size,
num_consumer_procs=num_consumer_procs,
num_episodes=num_episodes,
batch_size=1, # batch_size must be 1 for agentic producer
@@ -202,6 +198,7 @@ def launch_distributed(
model_config=inference_model_config,
generate_config=generate_config,
async_producers=_producer_procs,
tool_workers=tool_workers,
tokenizer_config=tokenizer_config,
agentic_config=agentic_config,
microbatch_size=1, # microbatch_size must be 1 for agentic producer