Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-11-27 10:51:58 +00:00
add custom agentic producer
@@ -4,8 +4,9 @@ import uuid
from typing import Any, Dict, Optional

import ray
from coati.distributed.agent.langgraph_math_agentic import LangGraphMathAgenticProducer
from coati.distributed.agent.qwen_math_agentic import QwenMathAgenticProducer
from coati.distributed.agent.agentic_producer import AgenticProducer
from coati.distributed.agent.qwen_math_agentic_producer import QwenMathAgenticProducer
from coati.distributed.agent.tool_worker import ToolWorker

from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
@@ -21,7 +22,7 @@ ALGO_MAP = {
Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer}
AGENTIC_PRODUCER_MAP = {
    "QwenMathAgent": QwenMathAgenticProducer,
    "LangGraphMathAgent": LangGraphMathAgenticProducer,
    "Agentic": AgenticProducer,
}  # supported agentic producers
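For context, a minimal sketch of how a registry like AGENTIC_PRODUCER_MAP is typically consumed further down in launch_distributed: the lookup key agentic_config["agentic_producer"] and the name agentic_producer_cls both appear in the hunks below, while the error handling here is illustrative only.

# Hypothetical selection step, assuming agentic_config["agentic_producer"]
# names one of the registered keys; the ValueError message is illustrative.
name = agentic_config["agentic_producer"]
if name not in AGENTIC_PRODUCER_MAP:
    raise ValueError(f"unknown agentic producer {name!r}, expected one of {list(AGENTIC_PRODUCER_MAP)}")
agentic_producer_cls = AGENTIC_PRODUCER_MAP[name]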
@@ -165,21 +166,16 @@ def launch_distributed(
        )
        producer_procs.append(producer)
    ray.get([p.setup.remote() for p in producer_procs])
    """
    # test async generate
    import torch
    import asyncio
    import time
    async def test():
        res_ref = producer_procs[0].generate.remote(torch.ones((2, 10), dtype=torch.int), torch.ones((2, 10), dtype=torch.int))
        res = await res_ref
        return res
    res = asyncio.run(test())
    print(res)
    time.sleep(1000)
    """

    if enable_agentic:
        from coati.distributed.agent.math_tools import repl_tool

        # set up tool workers
        tool_workers = []
        if agentic_config["agentic_producer"] == "Agentic":
            # 10 tool workers can handle 50 queries simultaneously
            # note that the imported repl_tool will be serialized and deserialized in each tool worker, so all workers can run in parallel
            tool_workers = [ToolWorker.remote([repl_tool]) for _ in range(agentic_config.get("num_tool_workers", 10))]
        # when agentic is enabled, we use core_producer as the inference engine and
        # AgenticProducer as the real producer
        _producer_procs = producer_procs
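A rough sketch of how a pool like tool_workers above might be used; the ToolWorker.remote([repl_tool]) construction matches the hunk, but the call method name and the round-robin dispatch are assumptions for illustration, not the repository's actual routing logic.

import itertools

import ray

# Hypothetical dispatch helper: rotate tool invocations across the pool so
# that no single ToolWorker actor becomes a bottleneck. The `call` method
# name is assumed for this sketch.
_worker_cycle = itertools.cycle(tool_workers)

def dispatch_tool_call(tool_name, tool_args):
    worker = next(_worker_cycle)
    return ray.get(worker.call.remote(tool_name, tool_args))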
@@ -194,7 +190,7 @@ def launch_distributed(
        producer_procs = [
            agentic_producer_cls.options(num_cpus=1).remote(
                producer_idx=producer_idx,
-               num_producers=num_producers * train_batch_size,
+               num_producers=num_producers * inference_batch_size,
                num_consumer_procs=num_consumer_procs,
                num_episodes=num_episodes,
                batch_size=1,  # batch_size must be 1 for agentic producer
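The size arithmetic in the changed line above can be read with concrete numbers; the figures below are illustrative only and assume each async inference producer serves inference_batch_size single-sample agentic producers.

# Illustrative numbers: 2 async inference producers, inference batch size 4.
num_producers = 2
inference_batch_size = 4
num_agentic_producers = num_producers * inference_batch_size  # 8 agentic producers
# Each agentic producer submits one rollout at a time (batch_size=1), so the
# shared inference engines can still be fed full batches of inference_batch_size.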
@@ -202,6 +198,7 @@ def launch_distributed(
                model_config=inference_model_config,
                generate_config=generate_config,
                async_producers=_producer_procs,
                tool_workers=tool_workers,
                tokenizer_config=tokenizer_config,
                agentic_config=agentic_config,
                microbatch_size=1,  # microbatch_size must be 1 for agentic producer
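Finally, a hypothetical agentic_config illustrating the two keys this diff reads; only "agentic_producer" and "num_tool_workers" are taken from the code above, and any other fields would be implementation-specific.

# Hypothetical config, assuming only the keys referenced in this diff.
agentic_config = {
    "agentic_producer": "Agentic",   # key into AGENTIC_PRODUCER_MAP
    "num_tool_workers": 10,          # size of the ToolWorker pool (default 10)
}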