mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-11-23 05:06:26 +00:00
add custom agentic producer
This commit is contained in:
@@ -0,0 +1,287 @@
|
||||
import copy
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
import ray
|
||||
from coati.distributed.agent.base import BaseAgenticProducer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
DEFAULT_SYSTEM_MESSAGE = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer> </answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>."""
|
||||
|
||||
|
||||
@ray.remote
|
||||
class AgenticProducer(BaseAgenticProducer):
|
||||
"""
|
||||
Asyncronous version of the producer that uses vLLM for generation.
|
||||
This class is designed to generate agentic response
|
||||
|
||||
Please use the following SYSTEM message or a similar one for the agentic math model:
|
||||
'''A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
|
||||
The Assistant first thinks about the reasoning process in the mind and then provides the user with
|
||||
the answer. The reasoning process and answer are enclosed within <reason> </reason> and <answer>
|
||||
</answer> tags, respectively, i.e., <reason> reasoning process here </reason><answer> answer here </answer>.'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
producer_idx,
|
||||
num_producers,
|
||||
num_consumer_procs,
|
||||
num_episodes,
|
||||
batch_size,
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
async_producers,
|
||||
tool_workers=[],
|
||||
tokenizer_config=None,
|
||||
agentic_config=None,
|
||||
microbatch_size=1,
|
||||
backend="transformers",
|
||||
num_generations: int = 8,
|
||||
consumer_plugin_config=None,
|
||||
eval_dataset_config=None,
|
||||
eval_interval=-1, # disable evaluation
|
||||
grpo_config: Dict[str, Any] = None,
|
||||
eval_save_dir: str = "./eval",
|
||||
eval_generation_config={},
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_log_file: str = "./rollout_log.jsonl",
|
||||
enable_profiling: bool = False,
|
||||
n_behind: int = 0,
|
||||
):
|
||||
assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer
|
||||
assert batch_size == 1 # batch_size must be 1 for agentic producer
|
||||
super().__init__(
|
||||
producer_idx,
|
||||
num_producers,
|
||||
num_consumer_procs,
|
||||
num_episodes,
|
||||
batch_size,
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
async_producers,
|
||||
tokenizer_config,
|
||||
microbatch_size,
|
||||
backend,
|
||||
num_generations,
|
||||
consumer_plugin_config,
|
||||
eval_dataset_config=eval_dataset_config,
|
||||
eval_interval=eval_interval,
|
||||
grpo_config=grpo_config,
|
||||
eval_save_dir=eval_save_dir,
|
||||
eval_generation_config=eval_generation_config,
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
log_rollout_interval=log_rollout_interval,
|
||||
rollout_log_file=rollout_log_file,
|
||||
enable_profiling=enable_profiling,
|
||||
n_behind=n_behind,
|
||||
)
|
||||
self.tool_workers = tool_workers
|
||||
self.agentic_config = model_config if not agentic_config else agentic_config
|
||||
self.agentic_config.update({"model": model_config["path"]})
|
||||
tokenizer_path = None
|
||||
if tokenizer_config and "path" in tokenizer_config:
|
||||
tokenizer_path = tokenizer_config["path"]
|
||||
elif "path" in model_config:
|
||||
tokenizer_path = model_config["path"]
|
||||
assert tokenizer_path is not None, "Tokenizer path must be provided either in tokenizer_config or model_config."
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
|
||||
self.tools_schema = []
|
||||
self.tool_call_budget = self.agentic_config.get("tool_call_budget", 3)
|
||||
self.llm_call_budget = self.agentic_config.get("llm_call_budget", 10)
|
||||
self.async_llm_engine_map = {}
|
||||
self._get_tools()
|
||||
|
||||
def _get_tools(self):
|
||||
"""
|
||||
SYSTEM message for the agentic math model. Reference: r-start2 paper https://arxiv.org/pdf/2508.20722
|
||||
"""
|
||||
tools = ray.get(self.tool_workers[0].list_tools.remote())
|
||||
tool_descriptions = {tool: ray.get(self.tool_workers[0].get_tool_description.remote(tool)) for tool in tools}
|
||||
tool_arg_schemas = {tool: ray.get(self.tool_workers[0].get_args_schema.remote(tool)) for tool in tools}
|
||||
self.tools = []
|
||||
for tool in tools:
|
||||
tool_schema = {"name": tool, "description": tool_descriptions[tool], "parameters": tool_arg_schemas[tool]}
|
||||
self.tools.append(tool_schema)
|
||||
|
||||
def _build_prompt(
|
||||
self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt"
|
||||
) -> dict:
|
||||
"""
|
||||
Build the prompt for the agentic math model.
|
||||
"""
|
||||
return self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=self.tools,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
return_dict=return_dict,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
|
||||
def _parse_response(self, response: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse the response from the agentic math model.
|
||||
|
||||
Sample Assistant Response:
|
||||
The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}<|im_end|>
|
||||
|
||||
Sample Assistant Response with Tool Call:
|
||||
To answer this, I will check both the weather and the timezone for New York.\n<tool_call>\n{"name": "get_weather", "arguments": {"location": "New York"}}\n</tool_call>\n<tool_call>\n{"name": "get_timezone", "arguments": {"location": "New York"}}\n</tool_call>
|
||||
|
||||
Sample Ouput:
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Let me check the current weather in Singapore by calling the weather tool.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"location": "New York"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"function": {
|
||||
"name": "get_timezone",
|
||||
"arguments": {
|
||||
"location": "New York"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The tool indicates that Singapore’s weather today is 31°C with partly cloudy skies and light showers. \\\\boxed{It is warm and slightly rainy in Singapore today.}"
|
||||
}
|
||||
"""
|
||||
# split by <im_end|>
|
||||
response_chunked = response.split("<|im_end|>")[0].strip()
|
||||
if "<tool_call>" in response_chunked:
|
||||
assistant_content = response_chunked.split("<tool_call>")[0].strip()
|
||||
tool_call_sections = response_chunked[response_chunked.find("<tool_call>") :].strip()
|
||||
# extract all tool calls
|
||||
tool_calls = []
|
||||
pattern = "<tool_call>(.*?)</tool_call>"
|
||||
matches = re.findall(pattern, tool_call_sections, re.DOTALL)
|
||||
for match in matches:
|
||||
try:
|
||||
tool_call = eval(match.strip())
|
||||
name = tool_call["name"]
|
||||
arguments = tool_call["arguments"]
|
||||
tool_calls.append({"function": {"name": name, "arguments": arguments}})
|
||||
except Exception as e:
|
||||
print(f"Failed to parse tool call: {match.strip()}. Error: {e}")
|
||||
tool_calls.append({"function": {"name": "return_parsing_error", "arguments": {}}})
|
||||
else:
|
||||
assistant_content = response_chunked
|
||||
tool_calls = []
|
||||
assistant_message = {"role": "assistant", "content": assistant_content}
|
||||
if tool_calls:
|
||||
assistant_message["tool_calls"] = tool_calls
|
||||
return assistant_message
|
||||
|
||||
def _select_tool_worker(self) -> ray.actor.ActorHandle:
|
||||
"""
|
||||
Select a tool worker based on the current load.
|
||||
"""
|
||||
loads = ray.get([worker.get_load.remote() for worker in self.tool_workers])
|
||||
min_load = min(loads)
|
||||
candidates = [i for i, l in enumerate(loads) if l == min_load]
|
||||
selected_idx = random.choice(candidates) # random tie break
|
||||
ray.get(self.tool_workers[selected_idx].increase_load.remote())
|
||||
return self.tool_workers[selected_idx]
|
||||
|
||||
def _select_async_producer(self, request_id) -> ray.actor.ActorHandle:
|
||||
"""
|
||||
Select an async producer based on the current load.
|
||||
"""
|
||||
# use the last used async producer if exists to reuse kv cache (as vllm use paged kv cache,
|
||||
# it will reuse most of the kv cache pages without recomputation)
|
||||
if request_id in self.async_llm_engine_map:
|
||||
return self.async_producers[self.async_llm_engine_map[request_id]]
|
||||
# otherwise select the least loaded async producer
|
||||
loads = ray.get([proc.get_producer_load.remote() for proc in self.async_producers])
|
||||
min_load = min(loads)
|
||||
candidates = [i for i, l in enumerate(loads) if l == min_load]
|
||||
selected_idx = random.choice(candidates) # random tie break
|
||||
self.async_llm_engine_map[request_id] = selected_idx
|
||||
return self.async_producers[selected_idx]
|
||||
|
||||
def _run_agentic_pipeline(self, messages):
|
||||
"""
|
||||
Run the agentic pipeline to generate responses based on the input messages.
|
||||
"""
|
||||
tool_call_count = 0
|
||||
llm_call_count = 0
|
||||
num_prompt_tokens = 0
|
||||
request_id = str(uuid4())
|
||||
logprobs = None
|
||||
while True:
|
||||
# tokenize the messages
|
||||
if llm_call_count > self.llm_call_budget:
|
||||
print(f"LLM call budget exceeded: {llm_call_count} > {self.llm_call_budget}. Stopping.")
|
||||
del self.async_llm_engine_map[request_id]
|
||||
while messages[-1]["role"] == "tool":
|
||||
messages.pop()
|
||||
return messages, logprobs
|
||||
inputs = self._build_prompt(messages, return_dict=True, return_tensors="pt")
|
||||
if num_prompt_tokens == 0:
|
||||
num_prompt_tokens = inputs["input_ids"].size(-1)
|
||||
if inputs["input_ids"].size(-1) - num_prompt_tokens > self.generate_config["max_tokens"]:
|
||||
print(
|
||||
f"Max tokens exceeded: Current have generated {inputs['input_ids'].size(-1) - num_prompt_tokens} tokens > {self.generate_config.get('max_tokens', 512)}. Stopping."
|
||||
)
|
||||
del self.async_llm_engine_map[request_id]
|
||||
while messages[-1]["role"] == "tool":
|
||||
messages.pop()
|
||||
return messages, logprobs
|
||||
async_producer = self._select_async_producer(request_id=request_id)
|
||||
agentic_generate_config = copy.deepcopy(self.generate_config)
|
||||
agentic_generate_config["max_tokens"] = self.agentic_config.get("max_tokens", 2048)
|
||||
response = ray.get(
|
||||
async_producer.generate.remote(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
request_id=request_id,
|
||||
**agentic_generate_config,
|
||||
)
|
||||
)
|
||||
llm_call_count += 1
|
||||
response_input_ids = response["input_ids"]
|
||||
logprobs = response["action_log_probs"]
|
||||
response_text = self.tokenizer.decode(
|
||||
response_input_ids[0][0][inputs["input_ids"].size(-1) :], skip_special_tokens=False
|
||||
)
|
||||
assistant_message = self._parse_response(response_text)
|
||||
messages.append(assistant_message)
|
||||
if "tool_calls" in assistant_message:
|
||||
if tool_call_count > self.tool_call_budget:
|
||||
print(f"Tool call budget exceeded: {tool_call_count} > {self.tool_call_budget}. Stopping.")
|
||||
del self.async_llm_engine_map[request_id]
|
||||
return messages, logprobs
|
||||
tool_call_count += len(assistant_message["tool_calls"])
|
||||
handlers = []
|
||||
for tool_call in assistant_message["tool_calls"]:
|
||||
# select a tool worker to execute the tool call
|
||||
tool_worker = self._select_tool_worker()
|
||||
handler = tool_worker.call.remote(tool_call["function"]["name"], tool_call["function"]["arguments"])
|
||||
handlers.append(handler)
|
||||
tool_results = ray.get(handlers)
|
||||
for tool_call, tool_result in zip(assistant_message["tool_calls"], tool_results):
|
||||
tool_message = {"role": "tool", "content": str(tool_result)}
|
||||
messages.append(tool_message)
|
||||
else:
|
||||
# no further tool call, return the messages
|
||||
del self.async_llm_engine_map[request_id]
|
||||
return messages, logprobs
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Dict
|
||||
|
||||
import ray
|
||||
@@ -86,6 +87,15 @@ class BaseAgenticProducer(BaseProducer):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _build_prompt(
|
||||
self, messages, add_generation_prompt: bool = True, return_dict=True, return_tensors="pt"
|
||||
) -> dict:
|
||||
"""
|
||||
Build the prompt from the input messages.
|
||||
This function should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Rollout function to generate responses for the input, for example, using LLM or agentic pipeline.
|
||||
@@ -93,9 +103,9 @@ class BaseAgenticProducer(BaseProducer):
|
||||
"""
|
||||
assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer"
|
||||
messages = kwargs["messages"][0]
|
||||
prompt_input_ids = self.tokenizer.apply_chat_template(
|
||||
messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
prompt_input_ids = self._build_prompt(
|
||||
messages, return_dict=True, return_tensors="pt", add_generation_prompt=True
|
||||
)["input_ids"]
|
||||
# add left padding
|
||||
prompt_length = prompt_input_ids.shape[1]
|
||||
max_prompt_length = self.train_dataset_config["max_length"]
|
||||
@@ -107,10 +117,16 @@ class BaseAgenticProducer(BaseProducer):
|
||||
"action_log_probs": [],
|
||||
"response_idx": [],
|
||||
}
|
||||
with ThreadPoolExecutor(max_workers=self.num_generations) as executor:
|
||||
results = list(
|
||||
executor.map(self._run_agentic_pipeline, [copy.deepcopy(messages) for _ in range(self.num_generations)])
|
||||
)
|
||||
|
||||
for i in range(self.num_generations):
|
||||
_messages = copy.deepcopy(messages)
|
||||
_messages = self._run_agentic_pipeline(_messages)
|
||||
response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True)
|
||||
_messages, logprobs = results[i]
|
||||
response_input_ids = self._build_prompt(
|
||||
_messages, return_dict=True, return_tensors="pt", add_generation_prompt=False
|
||||
)["input_ids"]
|
||||
# truncate if too long
|
||||
response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
|
||||
# add left right padding
|
||||
@@ -127,9 +143,14 @@ class BaseAgenticProducer(BaseProducer):
|
||||
) # [1, max_length-prompt_length]
|
||||
rollouts["attention_mask"].append(attention_mask)
|
||||
rollouts["action_mask"].append(action_mask)
|
||||
rollouts["action_log_probs"].append(
|
||||
torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length))
|
||||
) # dummy log probs
|
||||
truncated_logprobs = logprobs[:, :, prompt_length : prompt_length + self.generate_config["max_tokens"]]
|
||||
logprobs_padded = torch.nn.functional.pad(
|
||||
truncated_logprobs,
|
||||
(0, self.generate_config["max_tokens"] - truncated_logprobs.size(-1)),
|
||||
"constant",
|
||||
value=0.0,
|
||||
) # [1, max_new_tokens]
|
||||
rollouts["action_log_probs"].append(logprobs_padded[0])
|
||||
rollouts["response_idx"].append(
|
||||
torch.tensor(
|
||||
[
|
||||
@@ -141,7 +162,6 @@ class BaseAgenticProducer(BaseProducer):
|
||||
)
|
||||
) # [1, 2]
|
||||
rollouts["input_ids"].append(input_ids)
|
||||
# breakpoint()
|
||||
rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()} # [num_generations, ...]
|
||||
rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)])
|
||||
if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
|
||||
@@ -1,122 +0,0 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import ray
|
||||
from coati.distributed.agent.agentic import BaseAgenticProducer
|
||||
from coati.distributed.agent.langgraph_math_agentic_utils import CustomOpenAIAPILLM, LangChainCustomLLM, python
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
|
||||
@ray.remote
|
||||
class LangGraphMathAgenticProducer(BaseAgenticProducer):
|
||||
"""
|
||||
Asyncronous version of the producer that uses vLLM for generation.
|
||||
This class is designed to generate agentic response
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
producer_idx,
|
||||
num_producers,
|
||||
num_consumer_procs,
|
||||
num_episodes,
|
||||
batch_size,
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
async_producers,
|
||||
tokenizer_config=None,
|
||||
agentic_config=None,
|
||||
microbatch_size=1,
|
||||
backend="transformers",
|
||||
num_generations: int = 8,
|
||||
consumer_plugin_config=None,
|
||||
eval_dataset_config=None,
|
||||
eval_interval=-1, # disable evaluation
|
||||
grpo_config: Dict[str, Any] = None,
|
||||
eval_save_dir: str = "./eval",
|
||||
eval_generation_config={},
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_log_file: str = "./rollout_log.jsonl",
|
||||
enable_profiling: bool = False,
|
||||
n_behind: int = 0,
|
||||
):
|
||||
assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer
|
||||
assert batch_size == 1 # batch_size must be 1 for agentic producer
|
||||
super().__init__(
|
||||
producer_idx,
|
||||
num_producers,
|
||||
num_consumer_procs,
|
||||
num_episodes,
|
||||
batch_size,
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
async_producers,
|
||||
tokenizer_config,
|
||||
microbatch_size,
|
||||
backend,
|
||||
num_generations,
|
||||
consumer_plugin_config,
|
||||
eval_dataset_config=eval_dataset_config,
|
||||
eval_interval=eval_interval,
|
||||
grpo_config=grpo_config,
|
||||
eval_save_dir=eval_save_dir,
|
||||
eval_generation_config=eval_generation_config,
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
log_rollout_interval=log_rollout_interval,
|
||||
rollout_log_file=rollout_log_file,
|
||||
enable_profiling=enable_profiling,
|
||||
n_behind=n_behind,
|
||||
)
|
||||
self.agentic_config = agentic_config
|
||||
self.agentic_config.pop("agentic_type", None)
|
||||
self.llm_client = CustomOpenAIAPILLM({"model": model_config["path"]}, producer_idx, self.async_producers)
|
||||
self.llm = LangChainCustomLLM(self.llm_client)
|
||||
# self.python_repl = PythonREPL()
|
||||
# repl_tool = Tool(
|
||||
# name="python_repl",
|
||||
# description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.",
|
||||
# func=self.python_repl.run,
|
||||
# )
|
||||
# self.tools = [repl_tool]
|
||||
self.tools = [python]
|
||||
self.memory = MemorySaver()
|
||||
self.bot = create_react_agent(self.llm, self.tools, checkpointer=self.memory)
|
||||
|
||||
def _run_agentic_pipeline(self, messages):
|
||||
"""
|
||||
Run the agentic pipeline to generate responses based on the input messages using the LangGraph.
|
||||
"""
|
||||
assert (
|
||||
len(messages) == 2 and messages[0]["role"] == "system" and messages[1]["role"] == "user"
|
||||
), "Only support 1 system message and 1 user message as input."
|
||||
# inputs = messages
|
||||
for event in self.bot.stream(
|
||||
{"messages": [("system", messages[0]["content"]), ("user", "calculate the 1000th Fibonacci number")]},
|
||||
self.agentic_config,
|
||||
):
|
||||
continue
|
||||
breakpoint()
|
||||
|
||||
final_state = self.bot.get_state(self.agentic_config)
|
||||
transformer_messages = []
|
||||
for message in final_state[0]["messages"]:
|
||||
tool_calls = None
|
||||
if isinstance(message, SystemMessage):
|
||||
message.content
|
||||
elif isinstance(message, HumanMessage):
|
||||
message.content
|
||||
elif isinstance(message, AIMessage):
|
||||
message.content
|
||||
tool_calls = message.get("tool_calls", None) # [{"type": "function", "function": tool_call}]
|
||||
elif isinstance(message, ToolMessage):
|
||||
message.content
|
||||
|
||||
return transformer_messages
|
||||
@@ -1,185 +0,0 @@
|
||||
# -------------------------------
|
||||
# 1. Define the Python tool
|
||||
# -------------------------------
|
||||
import copy
|
||||
import io
|
||||
import random
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs.chat_generation import ChatGeneration
|
||||
from langchain_core.outputs.chat_result import ChatResult
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from tool_calling_llm import ToolCallingLLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
SYSTEM_PROMPT_TEMPLATE = """{task_description}. You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
Use the following format:
|
||||
|
||||
Question: the input question you must answer
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can repeat N times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
|
||||
Begin!
|
||||
|
||||
Question: {input}
|
||||
Thought:{agent_scratchpad}"""
|
||||
|
||||
SYSTEM_PROMPT = PromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE)
|
||||
|
||||
|
||||
class Capturing(list):
|
||||
"""Capture stdout prints inside exec()"""
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout = sys.stdout
|
||||
sys.stdout = self._stringio = io.StringIO()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.extend(self._stringio.getvalue().splitlines())
|
||||
sys.stdout = self._stdout
|
||||
|
||||
|
||||
@tool
|
||||
def python(code: str) -> str:
|
||||
"""
|
||||
This function executes a string of Python code and returns the printed output.
|
||||
You need to print the output. Please import all libraries used in the code string.
|
||||
"""
|
||||
local_vars = {}
|
||||
with Capturing() as output:
|
||||
exec(code, {}, local_vars)
|
||||
if output == []:
|
||||
return "Error: No output printed from the code. Please ensure you print the output."
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# 2. Define a Custom API LLM wrapper
|
||||
# -------------------------------
|
||||
class CustomOpenAIAPILLM:
|
||||
def __init__(self, cfg: dict, producer_idx, generation_workers=None):
|
||||
self.producer_idx = producer_idx
|
||||
self.generation_workers = generation_workers
|
||||
self.load_balancer_idx = producer_idx % len(self.generation_workers)
|
||||
assert "model" in cfg, "Please specify the model name in the config"
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(cfg["model"])
|
||||
self.role_mapping = {
|
||||
"system": "system",
|
||||
"user": "user",
|
||||
"assistant": "assistant",
|
||||
"human": "user",
|
||||
"tool": "tool",
|
||||
}
|
||||
|
||||
def invoke(self, messages: List[Dict[str, str]], **kwargs) -> str:
|
||||
"""
|
||||
messages: list of {"role": "user"/"assistant"/"system", "content": "..."}
|
||||
"""
|
||||
# load balancing
|
||||
load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers]
|
||||
min_load = min(load)
|
||||
candidates = [i for i, l in enumerate(load) if l == min_load]
|
||||
# random tie break
|
||||
self.load_balancer_idx = random.choice(candidates)
|
||||
generation_worker = self.generation_workers[self.load_balancer_idx]
|
||||
transformer_messages = []
|
||||
for message in messages:
|
||||
transformer_messages.append({"role": self.role_mapping[message.type], "content": message.content})
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
transformer_messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
rollouts = ray.get(generation_worker.generate.remote(input_ids, attention_mask, **kwargs))
|
||||
response = self.tokenizer.batch_decode(
|
||||
rollouts["input_ids"][0][:, input_ids.size(-1) :], skip_special_tokens=True
|
||||
)[0]
|
||||
return response
|
||||
|
||||
|
||||
class LangChainCustomLLM(ToolCallingLLM, BaseChatModel):
|
||||
client: CustomOpenAIAPILLM = None
|
||||
|
||||
def __init__(self, client: CustomOpenAIAPILLM):
|
||||
super().__init__()
|
||||
self.client = client
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
# content = self.client.invoke([m.dict() for m in messages])
|
||||
# chat_result = ChatResult(
|
||||
# generations=[ChatGeneration(message=AIMessage(content=content))]
|
||||
# )
|
||||
print("messages:", messages)
|
||||
breakpoint()
|
||||
system_message, functions = self._generate_system_message_and_functions(kwargs)
|
||||
sample_params = {"stop": stop} if stop is not None else {}
|
||||
sample_params.update({k: v for k, v in kwargs.items() if k in ["temperature", "top_p", "top_k", "max_tokens"]})
|
||||
messages_ = copy.deepcopy(messages)
|
||||
messages_[0].content = messages_[0].content + "\n" + system_message.content
|
||||
response_message = self.client.invoke( # type: ignore[safe-super]
|
||||
[system_message] + messages, **{"sample_params": sample_params}
|
||||
)
|
||||
breakpoint()
|
||||
response = self._process_response(AIMessage(content=response_message), functions)
|
||||
return ChatResult(generations=[ChatGeneration(message=response)])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "custom-api-llm"
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# 3. Build a ReAct Agent with LangGraph
|
||||
# -------------------------------
|
||||
def build_agent():
|
||||
# Wrap custom API LLM in LangChain-compatible interface
|
||||
|
||||
# Init LLM
|
||||
llm_client = CustomOpenAIAPILLM()
|
||||
llm = LangChainCustomLLM(llm_client)
|
||||
|
||||
# Tools
|
||||
tools = [python]
|
||||
|
||||
# Memory (optional)
|
||||
memory = MemorySaver()
|
||||
|
||||
# Build ReAct agent
|
||||
agent = create_react_agent(llm, tools, checkpointer=memory)
|
||||
return agent
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# 4. Run the agent on a math problem
|
||||
# -------------------------------
|
||||
if __name__ == "__main__":
|
||||
agent = build_agent()
|
||||
|
||||
# Example math question
|
||||
user_input = "What is the least common multiple of 18 and 24? Use Python if needed."
|
||||
|
||||
config = {"configurable": {"thread_id": "math-1"}}
|
||||
for event in agent.stream({"messages": [("user", user_input)]}, config):
|
||||
if "agent" in event:
|
||||
print("Agent event:", event["agent"]["messages"][-1].content)
|
||||
elif "tools" in event:
|
||||
print("Tool event:", event["tools"]["messages"][-1].content)
|
||||
|
||||
final_state = agent.get_state(config)
|
||||
print("Final Answer:", final_state["messages"][-1].content)
|
||||
@@ -0,0 +1,31 @@
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_experimental.utilities import PythonREPL
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
def make_title(field_name: str, field_info: FieldInfo) -> str:
|
||||
return field_name
|
||||
|
||||
|
||||
class PythonInput(BaseModel):
|
||||
code: str = Field(description="The python code to execute", field_title_generator=make_title)
|
||||
|
||||
|
||||
python_repl = PythonREPL()
|
||||
|
||||
|
||||
def run_python_code(code: str) -> str:
|
||||
if code.startswith("```python"):
|
||||
code = code.replace("```python", "```", 1).strip()
|
||||
if code.startswith("```py"): # qwen3 uses ```py
|
||||
code = code.replace("```py", "```", 1).strip()
|
||||
return python_repl.run(code, timeout=20)
|
||||
|
||||
|
||||
repl_tool = Tool(
|
||||
name="python_repl",
|
||||
description="A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.",
|
||||
func=run_python_code,
|
||||
args_schema=PythonInput,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import ray
|
||||
from coati.distributed.agent.agentic import BaseAgenticProducer
|
||||
from coati.distributed.agent.base import BaseAgenticProducer
|
||||
from coati.distributed.agent.qwen_math_agentic_utils import TIR_SYSTEM, CustomTransformers
|
||||
from qwen_agent.agents import TIRMathAgent
|
||||
|
||||
@@ -24,6 +24,7 @@ class QwenMathAgenticProducer(BaseAgenticProducer):
|
||||
model_config,
|
||||
generate_config,
|
||||
async_producers,
|
||||
tool_workers=[],
|
||||
tokenizer_config=None,
|
||||
agentic_config=None,
|
||||
microbatch_size=1,
|
||||
@@ -85,4 +86,5 @@ class QwenMathAgenticProducer(BaseAgenticProducer):
|
||||
for response in self.bot.run(messages):
|
||||
continue
|
||||
messages.extend(response)
|
||||
# breakpoint()
|
||||
return messages
|
||||
@@ -55,17 +55,18 @@ class LocalLLMFromGenerationWorkers:
|
||||
A class that wraps the Transformers model to support API-based text generation.
|
||||
"""
|
||||
|
||||
def __init__(self, generation_worker=None):
|
||||
def __init__(self, generation_worker=None, tokenizer=None):
|
||||
self.device = "cpu"
|
||||
self.generation_worker = generation_worker
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def generate(self, **kwargs):
|
||||
breakpoint()
|
||||
if "max_new_tokens" in kwargs:
|
||||
# we use VLLM backend for generation, which uses `max_tokens`
|
||||
kwargs["max_tokens"] = kwargs["max_new_tokens"]
|
||||
del kwargs["max_new_tokens"]
|
||||
rollouts = ray.get(self.generation_worker.generate.remote(**kwargs))
|
||||
# breakpoint()
|
||||
return rollouts["input_ids"]
|
||||
|
||||
|
||||
@@ -108,7 +109,7 @@ class CustomTransformers(Transformers):
|
||||
################################################################
|
||||
self.generation_workers = generation_workers
|
||||
self.hf_models = [
|
||||
LocalLLMFromGenerationWorkers(generation_worker=generation_worker)
|
||||
LocalLLMFromGenerationWorkers(generation_worker=generation_worker, tokenizer=self.tokenizer)
|
||||
for generation_worker in generation_workers
|
||||
]
|
||||
self.producer_idx = producer_idx
|
||||
@@ -133,10 +134,9 @@ class CustomTransformers(Transformers):
|
||||
candidates = [i for i, l in enumerate(load) if l == min_load]
|
||||
# random tie break
|
||||
self.load_balancer_idx = random.choice(candidates)
|
||||
# breakpoint()
|
||||
response = self._chat_no_stream(messages=messages, generate_cfg=generate_cfg)
|
||||
# if self.producer_idx == 0:
|
||||
# print(response)
|
||||
breakpoint()
|
||||
# breakpoint()
|
||||
yield response
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import ray
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
@ray.remote(concurrency_groups={"io": 1, "compute": 5})
|
||||
class ToolWorker:
|
||||
"""
|
||||
A unified wrapper class for LangChain tools, enabling a standard
|
||||
interface to call tools regardless of their internal differences.
|
||||
"""
|
||||
|
||||
def __init__(self, tools: List[BaseTool]):
|
||||
"""
|
||||
Initialize ToolWorker with a list of LangChain tools.
|
||||
|
||||
Args:
|
||||
tools (List[BaseTool]): List of LangChain tools to register.
|
||||
"""
|
||||
self._tool_registry: Dict[str, BaseTool] = {tool.name: tool for tool in tools}
|
||||
self.pending = 0
|
||||
|
||||
@ray.method(concurrency_group="io")
|
||||
def get_load(self) -> int:
|
||||
"""Return the current load of the worker."""
|
||||
return self.pending
|
||||
|
||||
@ray.method(concurrency_group="io")
|
||||
def increase_load(self):
|
||||
"""Increase the load counter."""
|
||||
self.pending += 1
|
||||
|
||||
@ray.method(concurrency_group="io")
|
||||
def list_tools(self) -> List[str]:
|
||||
"""Return the list of available tool names."""
|
||||
return list(self._tool_registry.keys())
|
||||
|
||||
@ray.method(concurrency_group="io")
|
||||
def get_tool_description(self, tool_name: str) -> Optional[str]:
|
||||
"""Return the description of a specific tool."""
|
||||
tool = self._tool_registry.get(tool_name)
|
||||
return tool.description if tool else None
|
||||
|
||||
@ray.method(concurrency_group="io")
|
||||
def get_args_schema(self, tool_name: str):
|
||||
"""Return the argument schema of a specific tool."""
|
||||
assert tool_name in self._tool_registry, f"Tool '{tool_name}' not found. Available: {self.list_tools()}"
|
||||
tool = self._tool_registry.get(tool_name)
|
||||
schema = tool.args_schema.model_json_schema(by_alias=False)
|
||||
return schema
|
||||
|
||||
@ray.method(concurrency_group="compute")
|
||||
def call(self, tool_name: str, input_data: Union[str, Dict[str, Any]], **kwargs) -> Any:
|
||||
"""
|
||||
Call a tool by name with input data.
|
||||
|
||||
Args:
|
||||
tool_name (str): Name of the tool to call.
|
||||
input_data (Union[str, Dict[str, Any]]): Input to pass to the tool.
|
||||
**kwargs: Extra keyword arguments for the tool.
|
||||
|
||||
Returns:
|
||||
Any: The tool's output.
|
||||
"""
|
||||
if tool_name == "return_parsing_error":
|
||||
self.pending -= 1
|
||||
return "Error: Tool call parsing error. Please use the correct JSON format."
|
||||
if tool_name not in self._tool_registry:
|
||||
return f"Error: Tool {tool_name} not found. Available tools: {self.list_tools()}"
|
||||
tool = self._tool_registry[tool_name]
|
||||
try:
|
||||
ret = tool.run(input_data, **kwargs)
|
||||
except Exception as e:
|
||||
ret = f"Error: Tool {tool_name} execution failed with error: {str(e)}"
|
||||
self.pending -= 1
|
||||
return ret
|
||||
@@ -344,7 +344,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
self.model_config = model_config
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = num_generations
|
||||
self.queued_requests = []
|
||||
self.running_requests = []
|
||||
self.microbatch_size = microbatch_size
|
||||
self.profiler = profiler
|
||||
|
||||
@@ -358,8 +358,10 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
input_ids (torch.Tensor): shape [B, S], B=1
|
||||
attention_mask (torch.Tensor): shape [B, S]
|
||||
"""
|
||||
# breakpoint()
|
||||
assert input_ids.size(0) == attention_mask.size(0) == 1
|
||||
request_id = (
|
||||
str(uuid4()) if not "request_id" in kwargs else kwargs.pop("request_id")
|
||||
) # use fixed request_id to reuse kv cache
|
||||
response_start_idx = input_ids.size(1)
|
||||
first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
|
||||
input_ids_no_padding = [input_ids.tolist()[0][first_non_padding_token_idx[0] :]]
|
||||
@@ -373,10 +375,10 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
out_len = []
|
||||
log_probs = []
|
||||
response_idx = []
|
||||
while len(self.queued_requests) >= self.microbatch_size:
|
||||
while len(self.running_requests) >= self.microbatch_size:
|
||||
# print(f"Current running {len(self.running_requests)}/{self.microbatch_size} requests, waiting...")
|
||||
await asyncio.sleep(0.1)
|
||||
request_id = str(uuid4())
|
||||
self.queued_requests.append(request_id) # enqueue
|
||||
self.running_requests.append(request_id) # enqueue
|
||||
# pop the first input_ids and attention_mask
|
||||
prompt_token_ids = input_ids_no_padding[0]
|
||||
self.profiler.enter(f"vllm generate {request_id}")
|
||||
@@ -386,14 +388,25 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
async for chunk in outputs:
|
||||
# generate the output tokens, can yield to avoid blocking
|
||||
pass
|
||||
self.queued_requests.remove(request_id) # dequeue
|
||||
for output_i in chunk.outputs:
|
||||
self.running_requests.remove(request_id) # dequeue
|
||||
if self.generate_config.get("prompt_logprobs", None) is not None:
|
||||
# when prompt_logprobs is not None, vllm will return logprobs for the whole sequence
|
||||
# for agentic producer, we return the logprobs of the whole sequence
|
||||
log_probs = [
|
||||
[m[t].logprob if m is not None else 0.0 for m, t in zip(chunk.prompt_logprobs, chunk.prompt_token_ids)]
|
||||
]
|
||||
for _ in range(sample_params.n - 1):
|
||||
log_probs.append([t for t in log_probs[0]]) # repeat the same logprobs for num_generations times
|
||||
else:
|
||||
log_probs = [[] for _ in range(sample_params.n)]
|
||||
|
||||
for generation_id, output_i in enumerate(chunk.outputs):
|
||||
out_len.append(len(output_i.token_ids))
|
||||
out_tokens.append(list(output_i.token_ids))
|
||||
response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
|
||||
assert len(output_i.logprobs) == len(output_i.token_ids)
|
||||
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
|
||||
log_probs.append(p)
|
||||
log_probs[generation_id].extend(p)
|
||||
self.profiler.exit(f"vllm generate {request_id}")
|
||||
# pad them
|
||||
max_len = self.sample_params.max_tokens
|
||||
@@ -402,7 +415,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
for i, new_token_ids in enumerate(out_tokens):
|
||||
pad_len = max_len - out_len[i]
|
||||
out_tokens[i] = new_token_ids + [self.tokenizer.pad_token_id] * pad_len
|
||||
log_probs[i] = log_probs[i] + [0.0] * pad_len
|
||||
log_probs[i] = log_probs[i] + [0.0] * (max_len - len(log_probs[i]))
|
||||
action_mask[i, out_len[i] :] = 0
|
||||
|
||||
out_tokens = torch.tensor(out_tokens)
|
||||
|
||||
@@ -4,8 +4,9 @@ import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import ray
|
||||
from coati.distributed.agent.langgraph_math_agentic import LangGraphMathAgenticProducer
|
||||
from coati.distributed.agent.qwen_math_agentic import QwenMathAgenticProducer
|
||||
from coati.distributed.agent.agentic_producer import AgenticProducer
|
||||
from coati.distributed.agent.qwen_math_agentic_producer import QwenMathAgenticProducer
|
||||
from coati.distributed.agent.tool_worker import ToolWorker
|
||||
|
||||
from .consumer import SimpleConsumer
|
||||
from .grpo_consumer import GRPOConsumer
|
||||
@@ -21,7 +22,7 @@ ALGO_MAP = {
|
||||
Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncSimpleProducer}
|
||||
AGENTIC_PRODUCER_MAP = {
|
||||
"QwenMathAgent": QwenMathAgenticProducer,
|
||||
"LangGraphMathAgent": LangGraphMathAgenticProducer,
|
||||
"Agentic": AgenticProducer,
|
||||
} # supported agentic producers
|
||||
|
||||
|
||||
@@ -165,21 +166,16 @@ def launch_distributed(
|
||||
)
|
||||
producer_procs.append(producer)
|
||||
ray.get([p.setup.remote() for p in producer_procs])
|
||||
"""
|
||||
# test async generate
|
||||
import torch
|
||||
import asyncio
|
||||
import time
|
||||
async def test():
|
||||
res_ref = producer_procs[0].generate.remote(torch.ones((2, 10), dtype=torch.int), torch.ones((2, 10), dtype=torch.int))
|
||||
res = await res_ref
|
||||
return res
|
||||
res = asyncio.run(test())
|
||||
print(res)
|
||||
time.sleep(1000)
|
||||
"""
|
||||
|
||||
if enable_agentic:
|
||||
from coati.distributed.agent.math_tools import repl_tool
|
||||
|
||||
# setup tool workers
|
||||
tool_workers = []
|
||||
if agentic_config["agentic_producer"] == "Agentic":
|
||||
# 10 tool workers can handle 50 queries simultaneously
|
||||
# note that imported repl_tool will be serialized and deserialized in each tool worker, therefore all workers can run parallely
|
||||
tool_workers = [ToolWorker.remote([repl_tool]) for _ in range(agentic_config.get("num_tool_workers", 10))]
|
||||
# when agentic is enabled, we use core_producer as inference engine and
|
||||
# AgenticProducer as the real producer
|
||||
_producer_procs = producer_procs
|
||||
@@ -194,7 +190,7 @@ def launch_distributed(
|
||||
producer_procs = [
|
||||
agentic_producer_cls.options(num_cpus=1).remote(
|
||||
producer_idx=producer_idx,
|
||||
num_producers=num_producers * train_batch_size,
|
||||
num_producers=num_producers * inference_batch_size,
|
||||
num_consumer_procs=num_consumer_procs,
|
||||
num_episodes=num_episodes,
|
||||
batch_size=1, # batch_size must be 1 for agentic producer
|
||||
@@ -202,6 +198,7 @@ def launch_distributed(
|
||||
model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
async_producers=_producer_procs,
|
||||
tool_workers=tool_workers,
|
||||
tokenizer_config=tokenizer_config,
|
||||
agentic_config=agentic_config,
|
||||
microbatch_size=1, # microbatch_size must be 1 for agentic producer
|
||||
|
||||
@@ -587,26 +587,6 @@ class BaseAsyncProducer(BaseProducer):
|
||||
self.condition = asyncio.Condition()
|
||||
self.data_ready_for_sending = []
|
||||
|
||||
# @torch.no_grad()
|
||||
# async def generate(self, input_ids, attention_mask, **kwargs):
|
||||
# tasks = []
|
||||
# print("input_ids:", input_ids)
|
||||
# for prompt_id in range(input_ids.size(0)):
|
||||
# new_kwargs = copy.deepcopy(kwargs)
|
||||
# if "gt_answer" in new_kwargs:
|
||||
# new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id]
|
||||
# if "test_cases" in new_kwargs:
|
||||
# new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id]
|
||||
# tasks.append(
|
||||
# self.model.generate(
|
||||
# input_ids[prompt_id].unsqueeze(0),
|
||||
# attention_mask[prompt_id].unsqueeze(0),
|
||||
# **new_kwargs,
|
||||
# )
|
||||
# )
|
||||
# rollouts = await asyncio.gather(*tasks)
|
||||
# return rollouts
|
||||
|
||||
@torch.no_grad()
|
||||
async def generate(self, input_ids, attention_mask, **kwargs):
|
||||
# naive rollout strategy
|
||||
@@ -647,7 +627,7 @@ class BaseAsyncProducer(BaseProducer):
|
||||
"""
|
||||
Get the load of each producer.
|
||||
"""
|
||||
return len(self.model.queued_requests)
|
||||
return len(self.model.running_requests)
|
||||
|
||||
async def async_sync_model(self, episode, step, num_processes: int = 1) -> None:
|
||||
"""
|
||||
|
||||
@@ -164,8 +164,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--agentic-type",
|
||||
type=str,
|
||||
default="QwenMathAgent",
|
||||
choices=["QwenMathAgent", "LangGraphMathAgent"],
|
||||
default="Agentic",
|
||||
choices=["Agentic", "QwenMathAgent"],
|
||||
help="Agentic model type for agentic training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -418,9 +418,6 @@ if __name__ == "__main__":
|
||||
if "agentic" in args.backend:
|
||||
assert "vllm" in args.backend, "Agentic backend only supports async-agentic-vllm backends."
|
||||
generate_config["n"] = 1 # agentic producer use AsyncProducer which processes one request a time
|
||||
generate_config["max_tokens"] = (
|
||||
2048 # max new tokens for each agentic step, usually smaller than max_new_tokens as agentic model will generate multiple steps
|
||||
)
|
||||
if args.agentic_type == "QwenMathAgent":
|
||||
agentic_config = {
|
||||
"agentic_producer": "QwenMathAgent",
|
||||
@@ -433,8 +430,15 @@ if __name__ == "__main__":
|
||||
agentic_config["generate_cfg"].update(
|
||||
{k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]}
|
||||
)
|
||||
elif args.agentic_type == "LangGraphMathAgent":
|
||||
agentic_config = {"configurable": {"thread_id": "math-1"}, "agentic_producer": "LangGraphMathAgent"}
|
||||
elif args.agentic_type == "Agentic":
|
||||
generate_config["stop"] = ["<|im_end|>"]
|
||||
generate_config["prompt_logprobs"] = 0
|
||||
agentic_config = {
|
||||
"agentic_producer": "Agentic",
|
||||
"tool_call_budget": 5,
|
||||
"llm_call_budget": 10,
|
||||
"max_tokens": 2048,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported agentic model type: {args.agentic_type}")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user