add custom agentic producer

This commit is contained in:
YeAnbang
2025-09-16 16:23:46 +08:00
parent 62f82a75ae
commit edcef9edaf
12 changed files with 482 additions and 378 deletions

View File

@@ -0,0 +1,287 @@
import copy
import random
import re
from typing import Any, Dict
from uuid import uuid4
import ray
from coati.distributed.agent.base import BaseAgenticProducer
from transformers import AutoTokenizer
DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer> </answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>."""
@ray.remote
class AgenticProducer(BaseAgenticProducer):
"""
Asynchronous version of the producer that uses vLLM for generation.
This class is designed to generate agentic responses.
Please use the following SYSTEM message or a similar one for the agentic math model:
'''A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The Assistant first thinks about the reasoning process in the mind and then provides the user with
the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer>
</answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>.'''
"""
def __init__(
self,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
async_producers,
tool_workers=[],
tokenizer_config=None,
agentic_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
n_behind: int = 0,
):
assert microbatch_size == 1, "microbatch_size must be 1 for agentic producer"
assert batch_size == 1, "batch_size must be 1 for agentic producer"
super().__init__(
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
async_producers,
tokenizer_config,
microbatch_size,
backend,
num_generations,
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
self.tool_workers = tool_workers
self.agentic_config = model_config if not agentic_config else agentic_config
self.agentic_config.update({"model": model_config["path"]})
tokenizer_path = None
if tokenizer_config and "path" in tokenizer_config:
tokenizer_path = tokenizer_config["path"]
elif "path" in model_config:
tokenizer_path = model_config["path"]
assert tokenizer_path is not None, "Tokenizer path must be provided either in tokenizer_config or model_config."
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
self.tools_schema = []
self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3)
self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10)
self.async_llm_engine_map = {}
self._get_tools()
def _get_tools(self):
"""
Collect tool schemas (name, description, and argument JSON schema) from the tool workers for use in the chat template. Reference: rStar2-Agent paper https://arxiv.org/pdf/2508.20722
"""
tools = ray.get(self.tool_workers[0].list_tools.remote())
tool_descriptions = {tool: ray.get(self.tool_workers[0].get_tool_description.remote(tool)) for tool in tools}
tool_arg_schemas = {tool: ray.get(self.tool_workers[0].get_args_schema.remote(tool)) for tool in tools}
self.tools = []
for tool in tools:
tool_schema = {"name": tool, "description": tool_descriptions[tool], "parameters": tool_arg_schemas[tool]}
self.tools.append(tool_schema)
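# Each tool_schema appended above is a plain dict of the tool's name, description, and
# argument JSON schema. Hypothetical example, assuming the python_repl tool from
# math_tools.py is the only registered tool:
#   {"name": "python_repl",
#    "description": "A Python shell. Use this to execute python commands. ...",
#    "parameters": {"type": "object", "title": "PythonInput",
#                   "properties": {"code": {"title": "code", "type": "string"}},
#                   "required": ["code"]}}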
def _build_prompt(
self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt"
) -> dict:
"""
Build the prompt for the agentic math model.
"""
return self.tokenizer.apply_chat_template(
messages,
tools=self.tools,
add_generation_prompt=add_generation_prompt,
return_dict=return_dict,
return_tensors=return_tensors,
)
def _parse_response(self, response: str) -> Dict[str, Any]:
"""
Parse the raw model response into an assistant message dict.
Sample assistant response (no tool call):
The tool indicates that Singapore's weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}<|im_end|>
Sample assistant response (with tool calls):
To answer this, I will check both the weather and the timezone for New York.\n<tool_call>\n{"name": "get_weather", "arguments": {"location": "New York"}}\n</tool_call>\n<tool_call>\n{"name": "get_timezone", "arguments": {"location": "New York"}}\n</tool_call>
Sample output (with tool calls):
{
    "role": "assistant",
    "content": "To answer this, I will check both the weather and the timezone for New York.",
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": {"location": "New York"}}},
        {"function": {"name": "get_timezone", "arguments": {"location": "New York"}}},
    ],
}
Sample output (no tool call):
{
    "role": "assistant",
    "content": "The tool indicates that Singapore's weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}",
}
"""
# split on <|im_end|> and keep only the text before it
response_chunked = response.split("<|im_end|>")[0].strip()
if "<tool_call>" in response_chunked:
assistant_content = response_chunked.split("<tool_call>")[0].strip()
tool_call_sections = response_chunked[response_chunked.find("<tool_call>") :].strip()
# extract all tool calls
tool_calls = []
pattern = "<tool_call>(.*?)</tool_call>"
matches = re.findall(pattern, tool_call_sections, re.DOTALL)
for match in matches:
try:
tool_call = eval(match.strip())
name = tool_call["name"]
arguments = tool_call["arguments"]
tool_calls.append({"function": {"name": name, "arguments": arguments}})
except Exception as e:
print(f"Failed to parse tool call: {match.strip()}. Error: {e}")
tool_calls.append({"function": {"name": "return_parsing_error", "arguments": {}}})
else:
assistant_content = response_chunked
tool_calls = []
assistant_message = {"role": "assistant", "content": assistant_content}
if tool_calls:
assistant_message["tool_calls"] = tool_calls
return assistant_message
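# Parsing mechanics: everything before the first <tool_call> becomes the assistant
# content; each <tool_call>...</tool_call> payload is a JSON-like dict literal that
# eval() turns into {"name": ..., "arguments": {...}}; payloads that fail to parse are
# mapped to the sentinel "return_parsing_error" tool, which ToolWorker answers with a
# formatting-error message so the model can correct itself on the next turn.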
def _select_tool_worker(self) -> ray.actor.ActorHandle:
"""
Select a tool worker based on the current load.
"""
loads = ray.get([worker.get_load.remote() for worker in self.tool_workers])
min_load = min(loads)
candidates = [i for i, l in enumerate(loads) if l == min_load]
selected_idx = random.choice(candidates) # random tie break
ray.get(self.tool_workers[selected_idx].increase_load.remote())
return self.tool_workers[selected_idx]
def _select_async_producer(self, request_id) -> ray.actor.ActorHandle:
"""
Select an async producer based on the current load.
"""
# reuse the async producer previously assigned to this request_id so its KV cache is hit
# (vLLM's paged KV cache lets it reuse most cached pages without recomputation)
if request_id in self.async_llm_engine_map:
return self.async_producers[self.async_llm_engine_map[request_id]]
# otherwise select the least loaded async producer
loads = ray.get([proc.get_producer_load.remote() for proc in self.async_producers])
min_load = min(loads)
candidates = [i for i, l in enumerate(loads) if l == min_load]
selected_idx = random.choice(candidates) # random tie break
self.async_llm_engine_map[request_id] = selected_idx
return self.async_producers[selected_idx]
def _run_agentic_pipeline(self, messages):
"""
Run the agentic pipeline to generate responses based on the input messages.
"""
tool_call_count = 0
llm_call_count = 0
num_prompt_tokens = 0
request_id = str(uuid4())
logprobs = None
while True:
# stop early if the LLM call budget is exhausted
if llm_call_count > self.llm_call_budget:
print(f"LLM call budget exceeded: {llm_call_count} > {self.llm_call_budget}. Stopping.")
del self.async_llm_engine_map[request_id]
while messages[-1]["role"] == "tool":
messages.pop()
return messages, logprobs
inputs = self._build_prompt(messages, return_dict=True, return_tensors="pt")
if num_prompt_tokens == 0:
num_prompt_tokens = inputs["input_ids"].size(-1)
if inputs["input_ids"].size(-1) - num_prompt_tokens > self.generate_config["max_tokens"]:
print(
f"Max tokens exceeded: Current have generated {inputs['input_ids'].size(-1) - num_prompt_tokens} tokens > {self.generate_config.get('max_tokens', 512)}. Stopping."
)
del self.async_llm_engine_map[request_id]
while messages[-1]["role"] == "tool":
messages.pop()
return messages, logprobs
async_producer = self._select_async_producer(request_id=request_id)
agentic_generate_config = copy.deepcopy(self.generate_config)
agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048)
response = ray.get(
async_producer.generate.remote(
inputs["input_ids"],
inputs["attention_mask"],
request_id=request_id,
**agentic_generate_config,
)
)
llm_call_count += 1
response_input_ids = response["input_ids"]
logprobs = response["action_log_probs"]
response_text = self.tokenizer.decode(
response_input_ids[0][0][inputs["input_ids"].size(-1) :], skip_special_tokens=False
)
assistant_message = self._parse_response(response_text)
messages.append(assistant_message)
if "tool_calls" in assistant_message:
if tool_call_count > self.tool_call_budget:
print(f"Tool call budget exceeded: {tool_call_count} > {self.tool_call_budget}. Stopping.")
del self.async_llm_engine_map[request_id]
return messages, logprobs
tool_call_count += len(assistant_message["tool_calls"])
handlers = []
for tool_call in assistant_message["tool_calls"]:
# select a tool worker to execute the tool call
tool_worker = self._select_tool_worker()
handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"])
handlers.append(handler)
tool_results = ray.get(handlers)
for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results):
tool_message = {"role": "tool", "content": str(tool_result)}
messages.append(tool_message)
else:
# no further tool calls; return the accumulated messages
del self.async_llm_engine_map[request_id]
return messages, logprobs
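The loop above alternates LLM calls and tool calls until the model stops emitting <tool_call> blocks or a budget is exhausted, and the fixed request_id pins every LLM call of one rollout to the same async producer so vLLM can reuse its paged KV cache. Below is a stripped-down, self-contained sketch of that control flow; llm_generate and run_tool are hypothetical stand-ins for the Ray calls, and the budget handling is simplified.

def run_agentic_loop(messages, llm_generate, run_tool, llm_call_budget=10, tool_call_budget=3):
    llm_calls, tool_calls_made = 0, 0
    while True:
        if llm_calls >= llm_call_budget or tool_calls_made >= tool_call_budget:
            # end the rollout on an assistant turn, dropping trailing tool results
            while messages and messages[-1]["role"] == "tool":
                messages.pop()
            return messages
        assistant = llm_generate(messages)  # parsed assistant message, as in _parse_response
        messages.append(assistant)
        llm_calls += 1
        tool_calls = assistant.get("tool_calls", [])
        if not tool_calls:
            return messages  # no tool call requested: the rollout is complete
        tool_calls_made += len(tool_calls)
        for call in tool_calls:
            messages.append({"role": "tool", "content": str(run_tool(call["function"]))})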

View File

@@ -1,5 +1,6 @@
import copy
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict
import ray
@@ -86,6 +87,15 @@ class BaseAgenticProducer(BaseProducer):
"""
raise NotImplementedError
def _build_prompt(
self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt"
) -> dict:
"""
Build the prompt from the input messages.
This function should be implemented in subclasses.
"""
raise NotImplementedError
def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
"""
Rollout function to generate responses for the input, for example, using LLM or agentic pipeline.
@@ -93,9 +103,9 @@ class BaseAgenticProducer(BaseProducer):
"""
assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer"
messages = kwargs["messages"][0]
prompt_input_ids = self.tokenizer.apply_chat_template(
messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
)
prompt_input_ids = self._build_prompt(
messages, return_dict=True, return_tensors="pt", add_generation_prompt=True
)["input_ids"]
# add left padding
prompt_length = prompt_input_ids.shape[1]
max_prompt_length = self.train_dataset_config["max_length"]
@@ -107,10 +117,16 @@ class BaseAgenticProducer(BaseProducer):
"action_log_probs": [],
"response_idx": [],
}
with ThreadPoolExecutor(max_workers=self.num_generations) as executor:
results = list(
executor.map(self._run_agentic_pipeline, [copy.deepcopy(messages) for _ in range(self.num_generations)])
)
for i in range(self.num_generations):
_messages = copy.deepcopy(messages)
_messages = self._run_agentic_pipeline(_messages)
response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True)
_messages, logprobs = results[i]
response_input_ids = self._build_prompt(
_messages, return_dict=True, return_tensors="pt", add_generation_prompt=False
)["input_ids"]
# truncate if too long
response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
# add left right padding
@@ -127,9 +143,14 @@ class BaseAgenticProducer(BaseProducer):
) # [1, max_length-prompt_length]
rollouts["attention_mask"].append(attention_mask)
rollouts["action_mask"].append(action_mask)
rollouts["action_log_probs"].append(
torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length))
) # dummy log probs
truncated_logprobs = logprobs[:, :, prompt_length : prompt_length + self.generate_config["max_tokens"]]
logprobs_padded = torch.nn.functional.pad(
truncated_logprobs,
(0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)),
"constant",
value=0.0,
) # [1, max_new_tokens]
rollouts["action_log_probs"].append(logprobs_padded[0])
rollouts["response_idx"].append(
torch.tensor(
[
@@ -141,7 +162,6 @@ class BaseAgenticProducer(BaseProducer):
)
) # [1, 2]
rollouts["input_ids"].append(input_ids)
# breakpoint()
rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()} # [num_generations, ...]
rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)])
if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
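The new logprob handling replaces the dummy log probs: the sequence-level log probs returned by the async producer are sliced to the response portion and right-padded with zeros to a fixed width of max_tokens so all generations stack into one tensor. A toy, self-contained version of that step (shapes and max_tokens=8 are made up):

import torch
import torch.nn.functional as F

max_tokens, prompt_length = 8, 3
logprobs = torch.randn(1, 1, prompt_length + 5)  # [1, 1, prompt + 5 response tokens]
truncated = logprobs[:, :, prompt_length : prompt_length + max_tokens]
padded = F.pad(truncated, (0, max_tokens - truncated.size(-1)), "constant", value=0.0)
print(padded.shape)  # torch.Size([1, 1, 8])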

View File

@@ -1,122 +0,0 @@
from typing import Any, Dict
import ray
from coati.distributed.agent.agentic import BaseAgenticProducer
from coati.distributed.agent.langgraph_math_agentic_utils import CustomOpenAIAPILLM, LangChainCustomLLM, python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
@ray.remote
class LangGraphMathAgenticProducer(BaseAgenticProducer):
"""
Asyncronous version of the producer that uses vLLM for generation.
This class is designed to generate agentic response
"""
def __init__(
self,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
async_producers,
tokenizer_config=None,
agentic_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
n_behind: int = 0,
):
assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer
assert batch_size == 1 # batch_size must be 1 for agentic producer
super().__init__(
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
async_producers,
tokenizer_config,
microbatch_size,
backend,
num_generations,
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
self.agentic_config = agentic_config
self.agentic_config.pop("agentic_type", None)
self.llm_client = CustomOpenAIAPILLM({"model": model_config["path"]}, producer_idx, self.async_producers)
self.llm = LangChainCustomLLM(self.llm_client)
# self.python_repl = PythonREPL()
# repl_tool = Tool(
# name="python_repl",
# description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.",
# func=self.python_repl.run,
# )
# self.tools = [repl_tool]
self.tools = [python]
self.memory = MemorySaver()
self.bot = create_react_agent(self.llm, self.tools, checkpointer=self.memory)
def _run_agentic_pipeline(self, messages):
"""
Run the agentic pipeline to generate responses based on the input messages using the LangGraph.
"""
assert (
len(messages) == 2 and messages[0]["role"] == "system" and messages[1]["role"] == "user"
), "Only support 1 system message and 1 user message as input."
# inputs = messages
for event in self.bot.stream(
{"messages": [("system", messages[0]["content"]), ("user", "calculate the 1000th Fibonacci number")]},
self.agentic_config,
):
continue
breakpoint()
final_state = self.bot.get_state(self.agentic_config)
transformer_messages = []
for message in final_state[0]["messages"]:
tool_calls = None
if isinstance(message, SystemMessage):
message.content
elif isinstance(message, HumanMessage):
message.content
elif isinstance(message, AIMessage):
message.content
tool_calls = message.get("tool_calls", None) # [{"type": "function", "function": tool_call}]
elif isinstance(message, ToolMessage):
message.content
return transformer_messages

View File

@@ -1,185 +0,0 @@
# -------------------------------
# 1. Define the Python tool
# -------------------------------
import copy
import io
import random
import sys
from typing import Dict, List
import ray
import torch
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.chat_result import ChatResult
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from tool_calling_llm import ToolCallingLLM
from transformers import AutoTokenizer
SYSTEM_PROMPT_TEMPLATE = """{task_description}. You have access to the following tools:
{tools}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
Question: {input}
Thought:{agent_scratchpad}"""
SYSTEM_PROMPT = PromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE)
class Capturing(list):
"""Capture stdout prints inside exec()"""
def __enter__(self):
self._stdout = sys.stdout
sys.stdout = self._stringio = io.StringIO()
return self
def __exit__(self, *args):
self.extend(self._stringio.getvalue().splitlines())
sys.stdout = self._stdout
@tool
def python(code: str) -> str:
"""
This function executes a string of Python code and returns the printed output.
You need to print the output. Please import all libraries used in the code string.
"""
local_vars = {}
with Capturing() as output:
exec(code, {}, local_vars)
if output == []:
return "Error: No output printed from the code. Please ensure you print the output."
return "\n".join(output)
# -------------------------------
# 2. Define a Custom API LLM wrapper
# -------------------------------
class CustomOpenAIAPILLM:
def __init__(self, cfg: dict, producer_idx, generation_workers=None):
self.producer_idx = producer_idx
self.generation_workers = generation_workers
self.load_balancer_idx = producer_idx % len(self.generation_workers)
assert "model" in cfg, "Please specify the model name in the config"
self.tokenizer = AutoTokenizer.from_pretrained(cfg["model"])
self.role_mapping = {
"system": "system",
"user": "user",
"assistant": "assistant",
"human": "user",
"tool": "tool",
}
def invoke(self, messages: List[Dict[str, str]], **kwargs) -> str:
"""
messages: list of {"role": "user"/"assistant"/"system", "content": "..."}
"""
# load balancing
load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers]
min_load = min(load)
candidates = [i for i, l in enumerate(load) if l == min_load]
# random tie break
self.load_balancer_idx = random.choice(candidates)
generation_worker = self.generation_workers[self.load_balancer_idx]
transformer_messages = []
for message in messages:
transformer_messages.append({"role": self.role_mapping[message.type], "content": message.content})
input_ids = self.tokenizer.apply_chat_template(
transformer_messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
)
attention_mask = torch.ones_like(input_ids)
rollouts = ray.get(generation_worker.generate.remote(input_ids, attention_mask, **kwargs))
response = self.tokenizer.batch_decode(
rollouts["input_ids"][0][:, input_ids.size(-1) :], skip_special_tokens=True
)[0]
return response
class LangChainCustomLLM(ToolCallingLLM, BaseChatModel):
client: CustomOpenAIAPILLM = None
def __init__(self, client: CustomOpenAIAPILLM):
super().__init__()
self.client = client
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
# content = self.client.invoke([m.dict() for m in messages])
# chat_result = ChatResult(
# generations=[ChatGeneration(message=AIMessage(content=content))]
# )
print("messages:", messages)
breakpoint()
system_message, functions = self._generate_system_message_and_functions(kwargs)
sample_params = {"stop": stop} if stop is not None else {}
sample_params.update({k: v for k, v in kwargs.items() if k in ["temperature", "top_p", "top_k", "max_tokens"]})
messages_ = copy.deepcopy(messages)
messages_[0].content = messages_[0].content + "\n" + system_message.content
response_message = self.client.invoke( # type: ignore[safe-super]
[system_message] + messages, **{"sample_params": sample_params}
)
breakpoint()
response = self._process_response(AIMessage(content=response_message), functions)
return ChatResult(generations=[ChatGeneration(message=response)])
@property
def _llm_type(self) -> str:
return "custom-api-llm"
# -------------------------------
# 3. Build a ReAct Agent with LangGraph
# -------------------------------
def build_agent():
# Wrap custom API LLM in LangChain-compatible interface
# Init LLM
llm_client = CustomOpenAIAPILLM()
llm = LangChainCustomLLM(llm_client)
# Tools
tools = [python]
# Memory (optional)
memory = MemorySaver()
# Build ReAct agent
agent = create_react_agent(llm, tools, checkpointer=memory)
return agent
# -------------------------------
# 4. Run the agent on a math problem
# -------------------------------
if __name__ == "__main__":
agent = build_agent()
# Example math question
user_input = "What is the least common multiple of 18 and 24? Use Python if needed."
config = {"configurable": {"thread_id": "math-1"}}
for event in agent.stream({"messages": [("user", user_input)]}, config):
if "agent" in event:
print("Agent event:", event["agent"]["messages"][-1].content)
elif "tools" in event:
print("Tool event:", event["tools"]["messages"][-1].content)
final_state = agent.get_state(config)
print("Final Answer:", final_state["messages"][-1].content)

View File

@@ -0,0 +1,31 @@
from langchain_core.tools import Tool
from langchain_experimental.utilities import PythonREPL
from pydantic import BaseModel, Field
from pydantic.fields import FieldInfo
def make_title(field_name: str, field_info: FieldInfo) -> str:
return field_name
class PythonInput(BaseModel):
code: str = Field(description="The python code to execute", field_title_generator=make_title)
python_repl = PythonREPL()
def run_python_code(code: str) -> str:
if code.startswith("```python"):
code = code.replace("```python", "```", 1).strip()
if code.startswith("```py"): # qwen3 uses ```py
code = code.replace("```py", "```", 1).strip()
return python_repl.run(code, timeout=20)
repl_tool = Tool(
name="python_repl",
description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.",
func=run_python_code,
args_schema=PythonInput,
)
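make_title exists so the JSON-schema title stays the raw field name ("code") rather than pydantic's default title-cased "Code", keeping the schema shown to the model consistent with the argument key it must emit. A small sketch of the difference (recent pydantic v2 assumed; class names here are made up):

from pydantic import BaseModel, Field

class DefaultTitle(BaseModel):
    code: str = Field(description="The python code to execute")

class RawTitle(BaseModel):
    code: str = Field(description="The python code to execute", field_title_generator=lambda name, info: name)

print(DefaultTitle.model_json_schema()["properties"]["code"]["title"])  # "Code"
print(RawTitle.model_json_schema()["properties"]["code"]["title"])      # "code"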

View File

@@ -1,7 +1,7 @@
from typing import Any, Dict
import ray
from coati.distributed.agent.agentic import BaseAgenticProducer
from coati.distributed.agent.base import BaseAgenticProducer
from coati.distributed.agent.qwen_math_agentic_utils import TIR_SYSTEM, CustomTransformers
from qwen_agent.agents import TIRMathAgent
@@ -24,6 +24,7 @@ class QwenMathAgenticProducer(BaseAgenticProducer):
model_config,
generate_config,
async_producers,
tool_workers=[],
tokenizer_config=None,
agentic_config=None,
microbatch_size=1,
@@ -85,4 +86,5 @@ class QwenMathAgenticProducer(BaseAgenticProducer):
for response in self.bot.run(messages):
continue
messages.extend(response)
# breakpoint()
return messages

View File

@@ -55,17 +55,18 @@ class LocalLLMFromGenerationWorkers:
A class that wraps the Transformers model to support API-based text generation.
"""
def __init__(self, generation_worker=None):
def __init__(self, generation_worker=None, tokenizer=None):
self.device = "cpu"
self.generation_worker = generation_worker
self.tokenizer = tokenizer
def generate(self, **kwargs):
breakpoint()
if "max_new_tokens" in kwargs:
# we use VLLM backend for generation, which uses `max_tokens`
kwargs["max_tokens"] = kwargs["max_new_tokens"]
del kwargs["max_new_tokens"]
rollouts = ray.get(self.generation_worker.generate.remote(**kwargs))
# breakpoint()
return rollouts["input_ids"]
@@ -108,7 +109,7 @@ class CustomTransformers(Transformers):
################################################################
self.generation_workers = generation_workers
self.hf_models = [
LocalLLMFromGenerationWorkers(generation_worker=generation_worker)
LocalLLMFromGenerationWorkers(generation_worker=generation_worker, tokenizer=self.tokenizer)
for generation_worker in generation_workers
]
self.producer_idx = producer_idx
@@ -133,10 +134,9 @@ class CustomTransformers(Transformers):
candidates = [i for i, l in enumerate(load) if l == min_load]
# random tie break
self.load_balancer_idx = random.choice(candidates)
# breakpoint()
response = self._chat_no_stream(messages=messages, generate_cfg=generate_cfg)
# if self.producer_idx == 0:
# print(response)
breakpoint()
# breakpoint()
yield response
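Both this producer and the new AgenticProducer use the same load-balancing rule: query every worker's load, keep the minimum, and break ties randomly so traffic does not always land on the first worker. A standalone sketch with plain integers in place of the ray.get load queries:

import random

def pick_least_loaded(loads):
    min_load = min(loads)
    candidates = [i for i, l in enumerate(loads) if l == min_load]
    return random.choice(candidates)  # random tie-break

print(pick_least_loaded([3, 1, 1, 2]))  # prints 1 or 2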

View File

@@ -0,0 +1,77 @@
from typing import Any, Dict, List, Optional, Union
import ray
from langchain.tools import BaseTool
@ray.remote(concurrency_groups={"io": 1, "compute": 5})
class ToolWorker:
"""
A unified wrapper class for LangChain tools, enabling a standard
interface to call tools regardless of their internal differences.
"""
def __init__(self, tools: List[BaseTool]):
"""
Initialize ToolWorker with a list of LangChain tools.
Args:
tools (List[BaseTool]): List of LangChain tools to register.
"""
self._tool_registry: Dict[str, BaseTool] = {tool.name: tool for tool in tools}
self.pending = 0
@ray.method(concurrency_group="io")
def get_load(self) -> int:
"""Return the current load of the worker."""
return self.pending
@ray.method(concurrency_group="io")
def increase_load(self):
"""Increase the load counter."""
self.pending += 1
@ray.method(concurrency_group="io")
def list_tools(self) -> List[str]:
"""Return the list of available tool names."""
return list(self._tool_registry.keys())
@ray.method(concurrency_group="io")
def get_tool_description(self, tool_name: str) -> Optional[str]:
"""Return the description of a specific tool."""
tool = self._tool_registry.get(tool_name)
return tool.description if tool else None
@ray.method(concurrency_group="io")
def get_args_schema(self, tool_name: str):
"""Return the argument schema of a specific tool."""
assert tool_name in self._tool_registry, f"Tool '{tool_name}' not found. Available: {self.list_tools()}"
tool = self._tool_registry.get(tool_name)
schema = tool.args_schema.model_json_schema(by_alias=False)
return schema
@ray.method(concurrency_group="compute")
def call(self, tool_name: str, input_data: Union[str, Dict[str, Any]], **kwargs) -> Any:
"""
Call a tool by name with input data.
Args:
tool_name (str): Name of the tool to call.
input_data (Union[str, Dict[str, Any]]): Input to pass to the tool.
**kwargs: Extra keyword arguments for the tool.
Returns:
Any: The tool's output.
"""
if tool_name == "return_parsing_error":
self.pending -= 1
return "Error: Tool call parsing error. Please use the correct JSON format."
if tool_name not in self._tool_registry:
self.pending -= 1
return f"Error: Tool {tool_name} not found. Available tools: {self.list_tools()}"
tool = self._tool_registry[tool_name]
try:
ret = tool.run(input_data, **kwargs)
except Exception as e:
ret = f"Error: Tool {tool_name} execution failed with error: {str(e)}"
self.pending -= 1
return ret
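For context, a minimal driver-side sketch of how the agentic producer talks to this actor (assumes Ray is not already initialized and uses the repl_tool from math_tools.py; the call sequence mirrors _select_tool_worker and the tool dispatch in agentic_producer.py):

import ray
from coati.distributed.agent.math_tools import repl_tool
from coati.distributed.agent.tool_worker import ToolWorker

ray.init()
workers = [ToolWorker.remote([repl_tool]) for _ in range(2)]
print(ray.get(workers[0].list_tools.remote()))                # ['python_repl']
loads = ray.get([w.get_load.remote() for w in workers])
worker = workers[loads.index(min(loads))]                     # least-loaded worker
ray.get(worker.increase_load.remote())                        # reserve a slot before calling
print(ray.get(worker.call.remote("python_repl", {"code": "print(1 + 1)"})))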

View File

@@ -344,7 +344,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
self.model_config = model_config
self.tokenizer = tokenizer
self.num_generations = num_generations
self.queued_requests = []
self.running_requests = []
self.microbatch_size = microbatch_size
self.profiler = profiler
@@ -358,8 +358,10 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
input_ids (torch.Tensor): shape [B, S], B=1
attention_mask (torch.Tensor): shape [B, S]
"""
# breakpoint()
assert input_ids.size(0) == attention_mask.size(0) == 1
request_id = (
str(uuid4()) if "request_id" not in kwargs else kwargs.pop("request_id")
)  # a caller-provided request_id is reused so vLLM can reuse its paged KV cache
response_start_idx = input_ids.size(1)
first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
input_ids_no_padding = [input_ids.tolist()[0][first_non_padding_token_idx[0] :]]
@@ -373,10 +375,10 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
out_len = []
log_probs = []
response_idx = []
while len(self.queued_requests) >= self.microbatch_size:
while len(self.running_requests) >= self.microbatch_size:
# print(f"Current running {len(self.running_requests)}/{self.microbatch_size} requests, waiting...")
await asyncio.sleep(0.1)
request_id = str(uuid4())
self.queued_requests.append(request_id) # enqueue
self.running_requests.append(request_id) # enqueue
# pop the first input_ids and attention_mask
prompt_token_ids = input_ids_no_padding[0]
self.profiler.enter(f"vllm generate {request_id}")
@@ -386,14 +388,25 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
async for chunk in outputs:
# generate the output tokens, can yield to avoid blocking
pass
self.queued_requests.remove(request_id) # dequeue
for output_i in chunk.outputs:
self.running_requests.remove(request_id) # dequeue
if self.generate_config.get("prompt_logprobs", None) is not None:
# when prompt_logprobs is not None, vLLM also returns logprobs for the prompt tokens,
# so for the agentic producer we return logprobs for the whole sequence (prompt + response)
log_probs = [
[m[t].logprob if m is not None else 0.0 for m, t in zip(chunk.prompt_logprobs, chunk.prompt_token_ids)]
]
for _ in range(sample_params.n - 1):
log_probs.append(list(log_probs[0]))  # repeat the same prompt logprobs for each of the n samples
else:
log_probs = [[] for _ in range(sample_params.n)]
for generation_id, output_i in enumerate(chunk.outputs):
out_len.append(len(output_i.token_ids))
out_tokens.append(list(output_i.token_ids))
response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
assert len(output_i.logprobs) == len(output_i.token_ids)
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
log_probs.append(p)
log_probs[generation_id].extend(p)
self.profiler.exit(f"vllm generate {request_id}")
# pad them
max_len = self.sample_params.max_tokens
@@ -402,7 +415,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
for i, new_token_ids in enumerate(out_tokens):
pad_len = max_len - out_len[i]
out_tokens[i] = new_token_ids + [self.tokenizer.pad_token_id] * pad_len
log_probs[i] = log_probs[i] + [0.0] * pad_len
log_probs[i] = log_probs[i] + [0.0] * (max_len - len(log_probs[i]))
action_mask[i, out_len[i] :] = 0
out_tokens = torch.tensor(out_tokens)
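Because prompt_logprobs makes each per-sample logprob list cover the prompt as well as the generated tokens, its length no longer matches the output-token count, so the padding switches from pad_len (based on out_len) to max_len - len(log_probs[i]). A toy illustration with made-up token ids and logprob values:

pad_token_id, max_len = 0, 6                      # max_len plays the role of sample_params.max_tokens
out_tokens = [[11, 12, 13]]                       # 3 generated token ids
log_probs = [[-0.5, -0.1, -0.2, -0.9, -0.3]]      # 2 prompt logprobs + 3 generation logprobs
for i, toks in enumerate(out_tokens):
    out_tokens[i] = toks + [pad_token_id] * (max_len - len(toks))
    log_probs[i] = log_probs[i] + [0.0] * (max_len - len(log_probs[i]))
print(out_tokens)   # [[11, 12, 13, 0, 0, 0]]
print(log_probs)    # [[-0.5, -0.1, -0.2, -0.9, -0.3, 0.0]]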

View File

@@ -4,8 +4,9 @@ import uuid
from typing import Any, Dict, Optional
import ray
from coati.distributed.agent.langgraph_math_agentic import LangGraphMathAgenticProducer
from coati.distributed.agent.qwen_math_agentic import QwenMathAgenticProducer
from coati.distributed.agent.agentic_producer import AgenticProducer
from coati.distributed.agent.qwen_math_agentic_producer import QwenMathAgenticProducer
from coati.distributed.agent.tool_worker import ToolWorker
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
@@ -21,7 +22,7 @@ ALGO_MAP = {
Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer}
AGENTIC_PRODUCER_MAP = {
"QwenMathAgent": QwenMathAgenticProducer,
"LangGraphMathAgent": LangGraphMathAgenticProducer,
"Agentic": AgenticProducer,
} # supported agentic producers
@@ -165,21 +166,16 @@ def launch_distributed(
)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
"""
# test async generate
import torch
import asyncio
import time
async def test():
res_ref = producer_procs[0].generate.remote(torch.ones((2, 10), dtype=torch.int), torch.ones((2, 10), dtype=torch.int))
res = await res_ref
return res
res = asyncio.run(test())
print(res)
time.sleep(1000)
"""
if enable_agentic:
from coati.distributed.agent.math_tools import repl_tool
# setup tool workers
tool_workers = []
if agentic_config["agentic_producer"] == "Agentic":
# 10 tool workers (compute concurrency 5 each) can serve up to 50 tool calls simultaneously
# note that the imported repl_tool is serialized and deserialized into each tool worker, so all workers run in parallel
tool_workers = [ToolWorker.remote([repl_tool]) for _ in range(agentic_config.get("num_tool_workers", 10))]
# when agentic is enabled, the core producers serve as inference engines and
# the AgenticProducer instances drive the actual rollouts
_producer_procs = producer_procs
@@ -194,7 +190,7 @@ def launch_distributed(
producer_procs = [
agentic_producer_cls.options(num_cpus=1).remote(
producer_idx=producer_idx,
num_producers=num_producers * train_batch_size,
num_producers=num_producers * inference_batch_size,
num_consumer_procs=num_consumer_procs,
num_episodes=num_episodes,
batch_size=1, # batch_size must be 1 for agentic producer
@@ -202,6 +198,7 @@ def launch_distributed(
model_config=inference_model_config,
generate_config=generate_config,
async_producers=_producer_procs,
tool_workers=tool_workers,
tokenizer_config=tokenizer_config,
agentic_config=agentic_config,
microbatch_size=1, # microbatch_size must be 1 for agentic producer

View File

@@ -587,26 +587,6 @@ class BaseAsyncProducer(BaseProducer):
self.condition = asyncio.Condition()
self.data_ready_for_sending = []
# @torch.no_grad()
# async def generate(self, input_ids, attention_mask, **kwargs):
# tasks = []
# print("input_ids:", input_ids)
# for prompt_id in range(input_ids.size(0)):
# new_kwargs = copy.deepcopy(kwargs)
# if "gt_answer" in new_kwargs:
# new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id]
# if "test_cases" in new_kwargs:
# new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id]
# tasks.append(
# self.model.generate(
# input_ids[prompt_id].unsqueeze(0),
# attention_mask[prompt_id].unsqueeze(0),
# **new_kwargs,
# )
# )
# rollouts = await asyncio.gather(*tasks)
# return rollouts
@torch.no_grad()
async def generate(self, input_ids, attention_mask, **kwargs):
# naive rollout strategy
@@ -647,7 +627,7 @@ class BaseAsyncProducer(BaseProducer):
"""
Get the load of each producer.
"""
return len(self.model.queued_requests)
return len(self.model.running_requests)
async def async_sync_model(self, episode, step, num_processes: int = 1) -> None:
"""

View File

@@ -164,8 +164,8 @@ if __name__ == "__main__":
parser.add_argument(
"--agentic-type",
type=str,
default="QwenMathAgent",
choices=["QwenMathAgent", "LangGraphMathAgent"],
default="Agentic",
choices=["Agentic", "QwenMathAgent"],
help="Agentic model type for agentic training.",
)
parser.add_argument(
@@ -418,9 +418,6 @@ if __name__ == "__main__":
if "agentic" in args.backend:
assert "vllm" in args.backend, "Agentic backend only supports async-agentic-vllm backends."
generate_config["n"] = 1 # agentic producer use AsyncProducer which processes one request a time
generate_config["max_tokens"] = (
2048 # max new tokens for each agentic step, usually smaller than max_new_tokens as agentic model will generate multiple steps
)
if args.agentic_type == "QwenMathAgent":
agentic_config = {
"agentic_producer": "QwenMathAgent",
@@ -433,8 +430,15 @@ if __name__ == "__main__":
agentic_config["generate_cfg"].update(
{k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]}
)
elif args.agentic_type == "LangGraphMathAgent":
agentic_config = {"configurable": {"thread_id": "math-1"}, "agentic_producer": "LangGraphMathAgent"}
elif args.agentic_type == "Agentic":
generate_config["stop"] = ["<|im_end|>"]
generate_config["prompt_logprobs"] = 0
agentic_config = {
"agentic_producer": "Agentic",
"tool_call_budget": 5,
"llm_call_budget": 10,
"max_tokens": 2048,
}
else:
raise ValueError(f"Unsupported agentic model type: {args.agentic_type}")
else: