support agentic rollout with async vLLM producers

This commit is contained in:
YeAnbang
2025-09-03 15:12:46 +08:00
parent d49a28dad0
commit 7f814e71f3
12 changed files with 947 additions and 439 deletions

View File

@@ -4,6 +4,7 @@
Dataloader for sft, dpo, ppo
"""
import copy
import os
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union
@@ -423,7 +424,9 @@ class RawConversationDataset(Dataset):
Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
"""
def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
def __init__(
self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str, tokenize=True
) -> None:
self.tokenizer = tokenizer
self.raw_texts = []
with jsonlines.open(input_file) as f:
@@ -432,30 +435,50 @@ class RawConversationDataset(Dataset):
self.tokenized_texts = [None] * len(self.raw_texts)
self.max_length = max_length
self.system_prompt = system_prompt
self.tokenize = tokenize
def __len__(self) -> int:
return len(self.raw_texts)
def __getitem__(self, index: int):
if self.tokenized_texts[index] is None:
message = self.raw_texts[index]
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]
if self.tokenize:
if self.tokenized_texts[index] is None:
message = self.raw_texts[index]
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]
else:
chat = copy.deepcopy(self.raw_texts[index])
chat["messages"] = [{"role": "system", "content": self.system_prompt}, chat["messages"]]
return chat
def collate_fn_grpo(batch):
input_ids = [item["input_ids"] for item in batch]
attention_mask = [item["attention_mask"] for item in batch]
labels = [item["labels"] for item in batch]
# Assume input_ids, attention_mask, labels are already of the same length,
# otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
input_ids = torch.stack(input_ids)
attention_mask = torch.stack(attention_mask)
labels = torch.stack(labels)
ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
if "test_cases" in batch[0]:
ret["test_cases"] = [item["test_cases"] for item in batch]
if "gt_answer" in batch[0]:
ret["gt_answer"] = [item["gt_answer"] for item in batch]
return ret
if "input_ids" in batch[0]:
# tokenized format
input_ids = [item["input_ids"] for item in batch]
attention_mask = [item["attention_mask"] for item in batch]
labels = [item["labels"] for item in batch]
# Assume input_ids, attention_mask, labels are already of the same length,
# otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
input_ids = torch.stack(input_ids)
attention_mask = torch.stack(attention_mask)
labels = torch.stack(labels)
ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
if "test_cases" in batch[0]:
ret["test_cases"] = [item["test_cases"] for item in batch]
if "gt_answer" in batch[0]:
ret["gt_answer"] = [item["gt_answer"] for item in batch]
return ret
elif "messages" in batch[0]:
# raw message format (tokenize=False, used by the vLLM chat / agentic rollout path)
ret = {
"messages": [item["messages"] for item in batch],
}
if "test_cases" in batch[0]:
ret["test_cases"] = [item["test_cases"] for item in batch]
if "gt_answer" in batch[0]:
ret["gt_answer"] = [item["gt_answer"] for item in batch]
return ret
else:
raise ValueError("Unsupported batch format")
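# Illustrative sketch, not part of this commit: exercising both branches of collate_fn_grpo.
# The field values below are made-up examples; only the field names match the dataset above.
def _collate_fn_grpo_example():
    tokenized_batch = [
        {
            "input_ids": torch.zeros(8, dtype=torch.long),
            "attention_mask": torch.ones(8, dtype=torch.long),
            "labels": torch.zeros(8, dtype=torch.long),
            "gt_answer": "42",
        }
    ]
    message_batch = [
        {
            "messages": [{"role": "user", "content": "What is 6 * 7?"}],
            "gt_answer": "42",
        }
    ]
    assert collate_fn_grpo(tokenized_batch)["input_ids"].shape == (1, 8)
    assert collate_fn_grpo(message_batch)["messages"] == [[{"role": "user", "content": "What is 6 * 7?"}]]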

View File

@@ -0,0 +1,199 @@
import copy
import json
from typing import Any, Dict
import ray
import torch
from coati.distributed.agent.agentic_math_utils import TIR_SYSTEM, CustomTransformers
from coati.distributed.producer import BaseProducer
from qwen_agent.agents import TIRMathAgent
from vllm import SamplingParams
@ray.remote
class AgenticProducer(BaseProducer):
"""
Asyncronous version of the producer that uses vLLM for generation.
This class is designed to generate agentic response
"""
def __init__(
self,
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
async_producers,
tokenizer_config=None,
agentic_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
grpo_config: Dict[str, Any] = None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
n_behind: int = 0,
):
assert microbatch_size == 1, "microbatch_size must be 1 for agentic producer"
assert batch_size == 1, "batch_size must be 1 for agentic producer"
super().__init__(
producer_idx,
num_producers,
num_consumer_procs,
num_episodes,
batch_size,
train_dataset_config,
model_config,
generate_config,
tokenizer_config,
microbatch_size,
backend,
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
n_behind=n_behind,
enable_agentic=True,
)
self.eval_generation_config = copy.deepcopy(generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
self.async_producers = async_producers
self.num_generations = num_generations
self.generate_config = generate_config
self.agentic_config = model_config if not agentic_config else agentic_config
self.agentic_config.update({"model": model_config["path"]})
self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers)
self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM)
def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
"""
Rollout function to generate responses for the input, for example, using LLM or agentic pipeline.
This function should be implemented in subclasses.
"""
assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer"
messages = kwargs["messages"][0]
prompt_input_ids = self.tokenizer.apply_chat_template(
messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
)
# add left padding
prompt_length = prompt_input_ids.shape[1]
max_prompt_length = self.train_dataset_config["max_length"]
to_pad_left = max_prompt_length - prompt_length
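# Resulting layout per generation (lengths sum to grpo_config["max_length"]):
# [ left pad: to_pad_left ][ prompt: prompt_length ][ response: response_length ][ right pad: to_pad_right ]
# so every prompt ends at index max_prompt_length and all rollouts share the same total length.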
rollouts = {
"input_ids": [],
"attention_mask": [],
"action_mask": [],
"action_log_probs": [],
"response_idx": [],
}
for i in range(self.num_generations):
_messages = copy.deepcopy(messages)
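# Run the TIR agent to completion; bot.run streams partial results, so only the final
# `response` (the complete list of new assistant/tool messages) is kept.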
for response in self.bot.run(messages):
continue
_messages.extend(response)
response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True)
# truncate if too long
response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
# add left and right padding
to_pad_right = self.grpo_config["max_length"] - response_input_ids.shape[1] - to_pad_left
response_length = response_input_ids.shape[1] - prompt_length
input_ids = torch.nn.functional.pad(
response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id
) # [1, max_length]
attention_mask = torch.nn.functional.pad(
torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), "constant", value=0
) # [1, max_length]
action_mask = torch.nn.functional.pad(
torch.ones(size=(1, response_length)), (0, to_pad_right), "constant", value=0
) # [1, max_length-prompt_length]
rollouts["attention_mask"].append(attention_mask)
rollouts["action_mask"].append(action_mask)
rollouts["action_log_probs"].append(
torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length))
) # dummy log probs
rollouts["response_idx"].append(
torch.tensor(
[
[
self.train_dataset_config["max_length"],
self.train_dataset_config["max_length"] + response_length,
]
]
)
) # [1, 2]
rollouts["input_ids"].append(input_ids)
rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()} # [num_generations, ...]
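# Shapes after stacking (one prompt per call): input_ids and attention_mask are
# [1, num_generations, grpo_config["max_length"]]; action_mask and action_log_probs are
# [1, num_generations, grpo_config["max_length"] - max_prompt_length].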
rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)])
if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
# the AgenticProducer is the main producer in agentic mode, so it owns rollout logging
# (the underlying AsyncSimpleProducer has rollout_log_file disabled)
if (
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
or self.latest_rollout_log_step == -1
):
new_record = (
json.dumps(
{
"train_step": self.consumer_global_step,
"rollout": self.tokenizer.batch_decode(
rollouts["input_ids"][:, 0], skip_special_tokens=True
),
}
)
+ "\n"
)
self.rollout_log_file.write(new_record)
self.rollout_log_file.flush()
self.latest_rollout_log_step = self.consumer_global_step
if "gt_answer" in kwargs:
rollouts["gt_answer"] = kwargs["gt_answer"]
if "test_cases" in kwargs:
rollouts["test_cases"] = kwargs["test_cases"]
return rollouts
def sync_model(self, episode, step) -> None:
"""
sync model from consumer to self.async_producers
AgenticProducer does not hold any model weights, so no need to sync model to self.async_producers
"""
tasks = []
for proc in self.async_producers:
tasks.append(proc.async_sync_model.remote(episode, step, self.num_producers))
ray.get(tasks)
return
def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
"""
sync data from self to consumer
"""
tasks = []
for idx, proc in enumerate(self.async_producers):
if idx == self.producer_idx % len(self.async_producers):
tasks.append(proc.async_sync_data.remote(data, self.num_producers))
else:
tasks.append(proc.async_sync_data.remote({}, self.num_producers))
ray.get(tasks)
return

View File

@@ -0,0 +1,170 @@
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A TIR(tool-integrated reasoning) math agent
```bash
python tir_math.py
```
"""
import os
import random
import ray
from qwen_agent.agents import TIRMathAgent
from qwen_agent.llm.base import register_llm
from qwen_agent.llm.function_calling import BaseFnCallModel
from qwen_agent.llm.transformers_llm import Transformers
from qwen_agent.log import logger
from transformers import AutoTokenizer
ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")
# The following two system prompts distinguish between CoT mode and TIR mode
TIR_SYSTEM = """Please integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{}."""
COT_SYSTEM = """Please reason step by step, and put your final answer within \\boxed{}."""
from transformers import StoppingCriteria
tokenizer = AutoTokenizer.from_pretrained("/mnt/nfs/share/data/model/Qwen2.5-Math-7B-Instruct", trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
def __init__(self, stop_token_ids):
self.stop_token_ids = stop_token_ids
def __call__(self, input_ids, scores, **kwargs):
# Check if the last token is one of the stop tokens
if input_ids[0, -1].item() in self.stop_token_ids:
return True
return False
class LocalLLMFromGenerationWorkers:
"""
A class that wraps the Transformers model to support API-based text generation.
"""
def __init__(self, generation_worker=None):
self.device = "cpu"
self.generation_worker = generation_worker
def generate(self, **kwargs):
rollouts = ray.get(self.generation_worker.generate.remote(**kwargs))
return rollouts["input_ids"]
@register_llm("api_based_transformers")
class CustomTransformers(Transformers):
"""
Transformers class that supports API-based text generation.
"""
def __init__(self, cfg: dict, producer_idx, generation_workers=None):
BaseFnCallModel.__init__(self, cfg) # skip the super() init of Transformers to avoid loading hf model
############ Setup logic from Transformers.__init__ ###############
if "model" not in cfg:
raise ValueError("Please provide the model id or directory through `model` in cfg.")
try:
from transformers import AutoConfig, AutoProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast
except ImportError as e:
raise ImportError(
"Could not import classes from transformers. " "Please install it with `pip install -U transformers`"
) from e
self.hf_config = AutoConfig.from_pretrained(cfg["model"])
arch = self.hf_config.architectures[0]
if len(self.hf_config.architectures) > 1:
logger.warning(
f"The config for the transformers model type contains more than one architecture, choosing the first: {arch}"
)
# Try loading a processor; if we get a tokenizer, treat the model as text-only.
processor = AutoProcessor.from_pretrained(cfg["model"])
if isinstance(processor, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
logger.info(f"Regarding the transformers model as text-only since its processor is a tokenizer.")
self.tokenizer = processor
self._support_multimodal_input = False
else:
self.processor = processor
self.tokenizer = self.processor.tokenizer
self._support_multimodal_input = True
################################################################
self.generation_workers = generation_workers
self.hf_models = [
LocalLLMFromGenerationWorkers(generation_worker=generation_worker)
for generation_worker in generation_workers
]
self.producer_idx = producer_idx
self.load_balancer_idx = producer_idx % len(self.generation_workers)
@property
def hf_model(self):
# Return the worker picked by the load balancer (least-load with random tie-break, refreshed in _chat_stream)
model = self.hf_models[self.load_balancer_idx]
return model
def _chat_stream(
self,
messages,
delta_stream: bool,
generate_cfg: dict,
):
# Override streaming: the HF streamer is not serializable across Ray workers, so fall back to a single non-streamed response.
# Pick the generation worker with the lowest load; the choice is refreshed for every generation.
load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers]
min_load = min(load)
candidates = [i for i, l in enumerate(load) if l == min_load]
# random tie break
self.load_balancer_idx = random.choice(candidates)
response = self._chat_no_stream(messages=messages, generate_cfg=generate_cfg)
yield response
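# Illustrative sketch, not part of this commit: the least-load selection with random
# tie-breaking used in _chat_stream above, shown without Ray for clarity.
def _pick_least_loaded(loads):
    """Return the index of a worker with the minimum load, breaking ties at random."""
    min_load = min(loads)
    candidates = [i for i, load in enumerate(loads) if load == min_load]
    return random.choice(candidates)
# e.g. _pick_least_loaded([3, 1, 1, 2]) returns 1 or 2, chosen at random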
def init_agent_service():
llm_cfg = {
# Path to the local model checkpoint used for this local test
"model": "/mnt/nfs/share/data/model/Qwen2.5-Math-7B-Instruct",
"model_type": "transformers",
"generate_cfg": {
# Greedy decoding for the local test
"top_k": 1,
},
}
llm = CustomTransformers(llm_cfg)
bot = TIRMathAgent(llm=llm, name="Qwen2.5-Math", system_message=TIR_SYSTEM)
return bot
def app_tui():
# Define the agent
bot = init_agent_service()
# Chat
messages = []
while True:
# Query example: the first 10 numbers of the Fibonacci sequence
query = input("user question: ")
messages.append({"role": "user", "content": query})
response = []
for response in bot.run(messages):
print("bot response:", response)
messages.extend(response)
# if __name__ == '__main__':
# # Test the TIR math agent locally
# app_tui()

View File

@@ -1,149 +0,0 @@
"""
MIT License
Copyright (c) 2025 LangChain
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field
class LangChainChatModel(BaseChatModel):
"""A custom chat model that echoes the first `parrot_buffer_length` characters
of the input.
When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.
Example:
.. code-block:: python
model = LangChainChatModel(parrot_buffer_length=2, model="bird-brain-001")
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""
model_name: str = Field(alias="model")
temperature: Optional[float] = None
max_tokens: Optional[int] = None
timeout: Optional[int] = None
stop: Optional[List[str]] = None
async_server_manager: Optional[Any] = None
max_retries: int = 2
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override the _generate method to implement the chat model logic.
This can be a call to an API, a call to a local model, or any other
implementation that generates a response to the input prompt.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
self.async_server_manager.generate(messages, stop, run_manager, **kwargs)
tokens = last_message.content[: self.parrot_buffer_length]
ct_input_tokens = sum(len(message.content) for message in messages)
ct_output_tokens = len(tokens)
message = AIMessage(
content=tokens,
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
"model_name": self.model_name,
},
usage_metadata={
"input_tokens": ct_input_tokens,
"output_tokens": ct_output_tokens,
"total_tokens": ct_input_tokens + ct_output_tokens,
},
)
##
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model.
This method should be implemented if the model can generate output
in a streaming fashion. If the model does not support streaming,
do not implement it. In that case streaming requests will be automatically
handled by the _generate method.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
raise NotImplementedError("Streaming is not implemented for this model. Please implement the _stream method.")
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return "echoing-chat-model-advanced"
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.
This information is used by the LangChain callback system, which
is used for tracing purposes make it possible to monitor LLMs.
"""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": self.model_name,
}

View File

@@ -0,0 +1,126 @@
# -------------------------------
# 1. Define the Python tool
# -------------------------------
import io
import sys
from typing import Any, Dict, List
import requests
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
class Capturing(list):
"""Capture stdout prints inside exec()"""
def __enter__(self):
self._stdout = sys.stdout
sys.stdout = self._stringio = io.StringIO()
return self
def __exit__(self, *args):
self.extend(self._stringio.getvalue().splitlines())
sys.stdout = self._stdout
@tool
def python(code: str) -> str:
"""
This function executes a string of Python code and returns the printed output.
You need to print the output. Please import all libraries used in the code string.
"""
local_vars = {}
with Capturing() as output:
exec(code, {}, local_vars)
if output == []:
return "Error: No output printed from the code. Please ensure you print the output."
return "\n".join(output)
# -------------------------------
# 2. Define a Custom API LLM wrapper
# -------------------------------
class CustomAPILLM:
def __init__(self, api_url: str, api_key: str = None):
self.api_url = api_url
self.api_key = api_key
def invoke(self, messages: List[Dict[str, str]]) -> str:
"""
messages: list of {"role": "user"/"assistant"/"system", "content": "..."}
"""
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
payload = {
"model": "custom-model", # depends on your API
"messages": messages,
"temperature": 0,
}
response = requests.post(self.api_url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
# Adjust according to your API response format
return data["choices"][0]["message"]["content"]
# -------------------------------
# 3. Build a ReAct Agent with LangGraph
# -------------------------------
def build_agent():
# Wrap custom API LLM in LangChain-compatible interface
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult
class LangChainCustomLLM(BaseChatModel):
client: Any = None  # the CustomAPILLM client; typed as Any so pydantic accepts an arbitrary class
def __init__(self, client: CustomAPILLM):
super().__init__()
self.client = client
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
content = self.client.invoke([m.dict() for m in messages])
# BaseChatModel has no _create_chat_result helper, so build the ChatResult directly
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=content))])
@property
def _llm_type(self) -> str:
return "custom-api-llm"
# Init LLM
llm_client = CustomAPILLM(api_url="http://localhost:8000/v1/chat/completions")
llm = LangChainCustomLLM(llm_client)
# Tools
tools = [python]
# Memory (optional)
memory = MemorySaver()
# Build ReAct agent
agent = create_react_agent(llm, tools, checkpointer=memory)
return agent
# -------------------------------
# 4. Run the agent on a math problem
# -------------------------------
if __name__ == "__main__":
agent = build_agent()
# Example math question
user_input = "What is the least common multiple of 18 and 24? Use Python if needed."
config = {"configurable": {"thread_id": "math-1"}}
for event in agent.stream({"messages": [("user", user_input)]}, config):
if "agent" in event:
print("Agent event:", event["agent"]["messages"][-1].content)
elif "tools" in event:
print("Tool event:", event["tools"]["messages"][-1].content)
final_state = agent.get_state(config)
print("Final Answer:", final_state["messages"][-1].content)

View File

@@ -1,112 +0,0 @@
"""
MIT License
Copyright (c) 2025 LangChain
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import builtins
import contextlib
import io
import math
from typing import Any
def eval(code: str, _locals: dict[str, Any]) -> tuple[str, dict[str, Any]]:
# Store original keys before execution
original_keys = set(_locals.keys())
try:
with contextlib.redirect_stdout(io.StringIO()) as f:
exec(code, builtins.__dict__, _locals)
result = f.getvalue()
if not result:
result = "<code ran, no output printed to stdout>"
except Exception as e:
result = f"Error during execution: {repr(e)}"
# Determine new variables created during execution
new_keys = set(_locals.keys()) - original_keys
new_vars = {key: _locals[key] for key in new_keys}
return result, new_vars
def add(a: float, b: float) -> float:
"""Add two numbers together."""
return a + b
def multiply(a: float, b: float) -> float:
"""Multiply two numbers together."""
return a * b
def divide(a: float, b: float) -> float:
"""Divide two numbers."""
return a / b
def subtract(a: float, b: float) -> float:
"""Subtract two numbers."""
return a - b
def sin(a: float) -> float:
"""Take the sine of a number."""
return math.sin(a)
def cos(a: float) -> float:
"""Take the cosine of a number."""
return math.cos(a)
def radians(a: float) -> float:
"""Convert degrees to radians."""
return math.radians(a)
def exponentiation(a: float, b: float) -> float:
"""Raise one number to the power of another."""
return a**b
def sqrt(a: float) -> float:
"""Take the square root of a number."""
return math.sqrt(a)
def ceil(a: float) -> float:
"""Round a number up to the nearest integer."""
return math.ceil(a)
tools = [
add,
multiply,
divide,
subtract,
sin,
cos,
radians,
exponentiation,
sqrt,
ceil,
]

View File

@@ -150,6 +150,7 @@ class BaseConsumer:
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
state_dict = self.state_dict()
print(f"[C{self.rank}]: Sync model before training")
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
ray_broadcast_tensor_dict(
@@ -164,6 +165,7 @@ class BaseConsumer:
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
print(f"[C{self.rank}]: Sync model before training done")
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
@@ -323,7 +325,7 @@ class BaseConsumer:
) # for setting start index when resuming training
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
# breakpoint()
if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
episode != 0 or step >= self.n_behind
):

View File

@@ -64,7 +64,7 @@ class AsyncInferenceBackend(BaseInferenceBackend):
- action_mask (torch.Tensor): shape [B, N]
where N is the number of generated tokens. And all tensors should be on CUDA.
"""
raise NotImplementedError("AsyncInferenceBackend does not support generate method.")
raise NotImplementedError("Generate method must be implemented in subclass.")
class TransformersInferenceBackend(BaseInferenceBackend):
@@ -84,6 +84,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
microbatch_size: int = 1,
profiler=None,
):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
@@ -93,6 +94,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer
self.num_generations = num_generations
self.profiler = profiler
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -158,6 +160,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
microbatch_size: int = 1,
profiler=None,
):
if sgl is None:
raise ImportError("sglang is not installed")
@@ -223,6 +226,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
microbatch_size: int = 1,
profiler=None,
):
if LLM is None:
raise ImportError("vllm is not installed")
@@ -323,6 +327,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
microbatch_size: int = 1,
profiler=None,
):
if LLM is None:
raise ImportError("vllm is not installed")
@@ -332,7 +337,8 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
if "n" not in generate_config:
generate_config.update({"n": num_generations})
self.generate_config = generate_config
self.sample_params = SamplingParams(**generate_config)
self.model_config = model_config
@@ -340,6 +346,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
self.num_generations = num_generations
self.queued_requests = []
self.microbatch_size = microbatch_size
self.profiler = profiler
@torch.no_grad()
async def generate(
@@ -351,6 +358,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
input_ids (torch.Tensor): shape [B, S], B=1
attention_mask (torch.Tensor): shape [B, S]
"""
# breakpoint()
assert input_ids.size(0) == attention_mask.size(0) == 1
response_start_idx = input_ids.size(1)
first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
@@ -366,6 +374,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
self.queued_requests.append(request_id) # enqueue
# pop the first input_ids and attention_mask
prompt_token_ids = input_ids_no_padding[0]
self.profiler.enter(f"vllm generate {request_id}")
outputs = self.engine.generate(
prompt={"prompt_token_ids": prompt_token_ids}, sampling_params=sample_params, request_id=request_id
)
@@ -380,6 +389,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
assert len(output_i.logprobs) == len(output_i.token_ids)
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
log_probs.append(p)
self.profiler.exit(f"vllm generate {request_id}")
# pad them
max_len = self.sample_params.max_tokens
action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)

View File

@@ -4,13 +4,13 @@ import uuid
from typing import Any, Dict, Optional
import ray
from coati.distributed.agent.agentic import AgenticProducer
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .producer import AsyncProducer, SimpleProducer
from .producer import AsyncSimpleProducer, SimpleProducer
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncProducer}
def get_jsonl_size_fast(path: str) -> int:
@@ -42,6 +42,7 @@ def launch_distributed(
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
grpo_config: Dict[str, Any],
agentic_config: Optional[Dict[str, Any]],
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
@@ -73,7 +74,7 @@ def launch_distributed(
num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size if "async" not in inference_backend else 1
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
wandb_group_name = str(uuid.uuid4())
@@ -105,9 +106,12 @@ def launch_distributed(
producer_procs = []
if "async" in inference_backend:
core_producer = AsyncProducer
core_producer = AsyncSimpleProducer
else:
core_producer = Producer_MAP.get("Simple", SimpleProducer)
core_producer = SimpleProducer
enable_agentic = "agentic" in inference_backend
if enable_agentic:
inference_backend = inference_backend.replace("agentic-", "")
for i in range(num_producers):
node_id = gpu_to_node_id[0]
producer_ip_address = gpu_to_ip_address[0]
@@ -125,7 +129,11 @@ def launch_distributed(
model_config=inference_model_config,
generate_config=generate_config,
tokenizer_config=tokenizer_config,
microbatch_size=inference_microbatch_size,
microbatch_size=(
inference_microbatch_size * num_generations
if "async" in inference_backend
else inference_microbatch_size
),
backend=inference_backend,
num_generations=num_generations,
consumer_plugin_config=plugin_config,
@@ -138,12 +146,63 @@ def launch_distributed(
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
rollout_log_file=rollout_log_file if not enable_agentic else None,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
"""
# test async generate
import torch
import asyncio
import time
async def test():
res_ref = producer_procs[0].generate.remote(torch.ones((2, 10), dtype=torch.int), torch.ones((2, 10), dtype=torch.int))
res = await res_ref
return res
res = asyncio.run(test())
print(res)
time.sleep(1000)
"""
if enable_agentic:
# when agentic is enabled, we use core_producer as inference engine and
# AgenticProducer as the real producer
_producer_procs = producer_procs
producer_procs = [
AgenticProducer.options(num_cpus=1).remote(
producer_idx=producer_idx,
num_producers=num_producers * train_batch_size,
num_consumer_procs=num_consumer_procs,
num_episodes=num_episodes,
batch_size=1, # batch_size must be 1 for agentic producer
train_dataset_config=train_dataset_config,
model_config=inference_model_config,
generate_config=generate_config,
async_producers=_producer_procs,
tokenizer_config=tokenizer_config,
agentic_config=agentic_config,
microbatch_size=1, # microbatch_size must be 1 for agentic producer
backend=inference_backend,
num_generations=num_generations,
consumer_plugin_config=plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
grpo_config=grpo_config,
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
for producer_idx in range(num_producers * inference_batch_size)
]
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(

View File

@@ -57,6 +57,7 @@ class BaseProducer:
log_rollout_interval: int = 20,
rollout_log_file: str = "./rollout_log.jsonl",
enable_profiling: bool = False,
enable_agentic: bool = False,
n_behind: int = 0,
):
self.producer_idx = producer_idx
@@ -65,8 +66,12 @@ class BaseProducer:
self.num_episodes = num_episodes
self.batch_size = batch_size
self.microbatch_size = microbatch_size
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
if not isinstance(self, BaseAsyncProducer):
assert batch_size % microbatch_size == 0, "batch_size must be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size
else:
assert microbatch_size > 0, "microbatch_size must be positive"
self.num_microbatches = max(1, batch_size // microbatch_size)
self.latest_eval_step = -1
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
@@ -84,13 +89,14 @@ class BaseProducer:
self.latest_rollout_log_step = -1
self.grpo_config = grpo_config
self.n_behind = n_behind
self.enable_agentic = enable_agentic
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
}
self.response_format_tags = grpo_config.get("response_format_tags", None)
if producer_idx == 0:
if producer_idx == 0 and rollout_log_file is not None:
if os.path.exists(rollout_log_file):
raise ValueError(
f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
@@ -121,7 +127,9 @@ class BaseProducer:
# init dataloader
train_dataset_path = train_dataset_config.pop("path")
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
self.train_dataset = RawConversationDataset(
self.tokenizer, train_dataset_path, **train_dataset_config, tokenize=not self.enable_agentic
)
self.train_dataloader = DataLoader(
self.train_dataset,
batch_size=microbatch_size,
@@ -159,7 +167,10 @@ class BaseProducer:
for eval_task_name in self.eval_dataset_config:
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
eval_dataset = RawConversationDataset(
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
self.tokenizer,
eval_dataset_path,
**eval_dataset_config[eval_task_name],
tokenize=not self.enable_agentic,
)
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
self.eval_dataloaders[eval_task_name] = DataLoader(
@@ -207,18 +218,34 @@ class BaseProducer:
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
"""
Generate responses by running inference on the input_ids and attention_mask.
"""
return self.model.generate(input_ids, attention_mask, **kwargs)
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
"""
Rollout function to generate responses for the input, for example, using an LLM or an agentic pipeline.
This function should be implemented in subclasses.
"""
raise NotImplementedError
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
raise NotImplementedError
def loop(self) -> None:
def sync_model(self, episode, step) -> None:
"""
Default implementation to sync model from consumer to producer.
"""
torch.cuda.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size):
print(
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(step + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
@@ -226,6 +253,7 @@ class BaseProducer:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
print(f"[P{self.producer_idx}] Sync model episode {episode} step {(step + 1) // self.num_microbatches - 1}")
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
@@ -233,10 +261,18 @@ class BaseProducer:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
print(f"[P{self.producer_idx}] Sync initial model done.")
del state_dict
torch.cuda.empty_cache()
def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
"""
Default implementation to sync data from producer to consumer.
"""
ray_broadcast_tensor_dict(data, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}")
def loop(self) -> None:
self.sync_model(0, 0)
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
@@ -308,14 +344,12 @@ class BaseProducer:
self.eval_mode = False
self.latest_eval_step = self.consumer_global_step
self.profiler.enter("rollout")
if isinstance(self.model, BACKEND_MAP["async-vllm"]):
outputs = asyncio.run(self.rollout(**batch))
else:
outputs = self.rollout(**batch)
outputs = self.rollout(**batch)
self.profiler.exit("rollout")
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
if "temperature" not in outputs:
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
self.profiler.enter("calculate_reward")
if self.grpo_config["reward_fn_type"] == "code":
@@ -360,52 +394,16 @@ class BaseProducer:
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
self.profiler.enter("send_broadcast_data")
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)
self.sync_data(outputs)
self.profiler.exit("send_broadcast_data")
if (
(i + 1) % self.num_microbatches == 0
and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
):
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.sleep() # revict KV_cache to avoid OOM
# don't sync model for last iteration
torch.cuda.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size):
print(
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
del state_dict
torch.cuda.empty_cache()
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.wake_up()
self.sync_model(episode, i)
# linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0:
if episode <= 0 and hasattr(self, "model"):
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
"temperature"
@@ -475,7 +473,7 @@ class SimpleProducer(BaseProducer):
n_behind=n_behind,
)
self.model = self.backend_cls(
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size, profiler=self.profiler
)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
@@ -484,7 +482,7 @@ class SimpleProducer(BaseProducer):
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
rollouts = self.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 0 and not self.eval_mode:
if (
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
@@ -516,8 +514,7 @@ class SimpleProducer(BaseProducer):
self.model.load_state_dict(state_dict)
@ray.remote
class AsyncProducer(BaseProducer):
class BaseAsyncProducer(BaseProducer):
"""
Asynchronous version of the producer that uses vLLM for generation.
"""
@@ -577,15 +574,39 @@ class AsyncProducer(BaseProducer):
)
assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}"
self.model = self.backend_cls(
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size, profiler=self.profiler
)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
self.ready_processes = 0
self.condition = asyncio.Condition()
self.data_ready_for_sending = []
# @torch.no_grad()
# async def generate(self, input_ids, attention_mask, **kwargs):
# tasks = []
# print("input_ids:", input_ids)
# for prompt_id in range(input_ids.size(0)):
# new_kwargs = copy.deepcopy(kwargs)
# if "gt_answer" in new_kwargs:
# new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id]
# if "test_cases" in new_kwargs:
# new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id]
# tasks.append(
# self.model.generate(
# input_ids[prompt_id].unsqueeze(0),
# attention_mask[prompt_id].unsqueeze(0),
# **new_kwargs,
# )
# )
# rollouts = await asyncio.gather(*tasks)
# return rollouts
@torch.no_grad()
async def rollout(self, input_ids, attention_mask, **kwargs):
async def generate(self, input_ids, attention_mask, **kwargs):
# naive rollout strategy
tasks = []
for prompt_id in range(input_ids.size(0)):
new_kwargs = copy.deepcopy(kwargs)
@@ -600,37 +621,224 @@ class AsyncProducer(BaseProducer):
**new_kwargs,
)
)
# print(f"Producer {self.producer_idx} running {len(tasks)} tasks")
rollouts = await asyncio.gather(*tasks)
rollouts = {
k: (
torch.cat([r[k] for r in rollouts], dim=0)
if k not in ["gt_answer", "test_cases"]
else [r[k] for r in rollouts]
)
).cpu() # CUDA tensor is not serializable by ray
for k in rollouts[0].keys()
}
if self.producer_idx == 0 and not self.eval_mode:
if (
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
or self.latest_rollout_log_step == -1
):
new_record = (
json.dumps(
{
"train_step": self.consumer_global_step,
"rollout": self.tokenizer.batch_decode(
rollouts["input_ids"][:, 0], skip_special_tokens=True
),
}
)
+ "\n"
)
self.rollout_log_file.write(new_record)
self.rollout_log_file.flush()
self.latest_rollout_log_step = self.consumer_global_step
return rollouts
@torch.no_grad()
async def rollout(self, input_ids, attention_mask, **kwargs):
"""
Advanced distributed rollout strategy that dispatches the generation tasks to different DP ranks.
Must be implemented in subclasses.
"""
raise NotImplementedError("rollout must be implemented in subclasses")
async def get_producer_load(self):
"""
Return the current load of this producer (number of queued generation requests).
"""
return len(self.model.queued_requests)
async def async_sync_model(self, episode, step, num_processes: int = 1) -> None:
"""
Asynchronous barrier for syncing the model from consumer to producer.
Called by other producers (e.g. the agentic producers); the sync runs once after all callers arrive.
"""
async with self.condition:
self.ready_processes += 1
# Wait until all processes are ready
if self.ready_processes < num_processes:
await self.condition.wait()
# Only one process should reset `ready_processes` and perform the sync
if self.ready_processes == num_processes:
self.ready_processes = 0
self.condition.notify_all() # Notify all waiting processes
self.sync_model(episode, step)
async def async_sync_data(self, data: Dict[str, torch.Tensor], num_processes: int = 1) -> None:
# merge data dict
async with self.condition:
self.ready_processes += 1
if data:
self.data_ready_for_sending.append(data)
# Wait until all processes are ready
if self.ready_processes < num_processes:
await self.condition.wait()
# Only one process should reset `ready_processes` and perform the sync
if self.ready_processes == num_processes: # wait for all producers to join
self.ready_processes = 0
self.condition.notify_all()
# merge data for sending
if len(self.data_ready_for_sending) >= 1:
batch_rollout_data = {}
for key in self.data_ready_for_sending[0]:
batch_rollout_data[key] = torch.cat([d[key] for d in self.data_ready_for_sending], dim=0).to(
self.device
)
self.sync_data(batch_rollout_data)
self.data_ready_for_sending = [] # reset
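# Worked example of the barrier above with num_processes=3: callers 1 and 2 increment
# ready_processes and wait; caller 3 sees ready_processes == 3, resets the counter,
# notifies the waiters and performs the single merge-and-send. The woken callers then
# see ready_processes == 0 and skip the action, so it runs exactly once per round.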
async def loop(self) -> None:
self.sync_model(0, 0)
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
print(
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
)
for episode in range(self.num_episodes):
self.train_dataloader.sampler.set_epoch(episode)
for i, batch in enumerate(self.train_dataloader):
if i >= num_valid_microbatches:
break
if self.eval_interval > 0 and self.eval_dataset_config is not None:
if (
self.consumer_global_step - self.latest_eval_step >= self.eval_interval
and self.consumer_global_step > self.latest_eval_step
) or self.latest_eval_step == -1:
to_log_msg = {}
self.eval_mode = True
for eval_task_name in self.eval_dataloaders:
if self.producer_idx == 0:
print(
f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
)
eval_results = []
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
for eval_batch in tqdm.tqdm(
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
):
eval_outputs = await self.rollout(**eval_batch, sample_params=self.eval_sample_params)
eval_results = eval_results + [
self.evaluation_function(
eval_outputs["input_ids"][m][n],
eval_outputs[
(
"test_cases"
if self.grpo_config["reward_fn_type"] == "code"
else "gt_answer"
)
][m],
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
tags=self.response_format_tags,
)
for m in range(eval_outputs["input_ids"].size(0))
for n in range(eval_outputs["input_ids"].size(1))
]
eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results])
eval_statistics_tensor[1] += len(eval_results)
allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
to_log_msg[f"eval/{eval_task_name}"] = (
eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
)
if self.producer_idx == 0:
print(
f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
)
# save eval results
safe_append_to_jsonl_file(
os.path.join(
self.eval_save_dir,
f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
),
eval_results,
)
if self.producer_idx == 0:
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
self.eval_mode = False
self.latest_eval_step = self.consumer_global_step
self.profiler.enter("rollout")
# breakpoint()
outputs = await self.rollout(**batch)
self.profiler.exit("rollout")
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
self.profiler.enter("calculate_reward")
if self.grpo_config["reward_fn_type"] == "code":
test_cases = []
for prompt_id in range(bs):
test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
test_cases=test_cases,
response_idx=outputs["response_idx"].view((-1, 2)),
)
else:
gt_answer = []
for prompt_id in range(bs):
gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
reward_model_output = self.reward_model(
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
gt_answer=gt_answer,
response_idx=outputs["response_idx"].view((-1, 2)),
)
outputs["reward"] = (
torch.tensor([value[0] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["format_acc"] = (
torch.tensor([value[1] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
outputs["ans_acc"] = (
torch.tensor([value[2] for value in reward_model_output])
.to(outputs["input_ids"].device)
.view((bs, num_gen, 1))
)
if "gt_answer" in outputs:
outputs.pop("gt_answer")
if "test_cases" in outputs:
outputs.pop("test_cases")
self.profiler.exit("calculate_reward")
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
self.profiler.enter("send_broadcast_data")
self.sync_data(outputs)
self.profiler.exit("send_broadcast_data")
if (
(i + 1) % self.num_microbatches == 0
and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
):
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.sleep() # evict KV cache to avoid OOM
# don't sync model for last iteration
self.sync_model(episode, i)
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.wake_up()
# linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0:
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.9
if isinstance(self.model, BACKEND_MAP["vllm"]):
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.9
def __del__(self):
if self.producer_idx == 0:
self.wandb_run.finish()
@@ -642,65 +850,18 @@ class AsyncProducer(BaseProducer):
@ray.remote
class AsyncServer:
class AsyncSimpleProducer(BaseAsyncProducer):
"""
A async worker for inference only
Asynchronous version of the producer that uses vLLM for generation.
This class is designed to handle multiple producer actors and distribute tasks among them.
"""
def __init__(
self,
producer_idx,
num_producers,
model_config,
generate_config,
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
eval_generation_config={},
):
tokenizer_path = model_config["path"]
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
self.tokenizer.padding_side = "left"
self.microbatch_size = microbatch_size
self.producer_idx = producer_idx
self.num_producers = num_producers
assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}"
self.model = self.backend_cls(
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size
)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_generation_config.update(eval_generation_config)
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
@torch.no_grad()
async def rollout(self, input_ids, attention_mask, **kwargs):
tasks = []
for prompt_id in range(input_ids.size(0)):
new_kwargs = copy.deepcopy(kwargs)
if "gt_answer" in new_kwargs:
new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id]
if "test_cases" in new_kwargs:
new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id]
tasks.append(
self.model.generate(
input_ids[prompt_id].unsqueeze(0),
attention_mask[prompt_id].unsqueeze(0),
**new_kwargs,
)
)
# print(f"Producer {self.producer_idx} running {len(tasks)} tasks")
rollouts = await asyncio.gather(*tasks)
rollouts = {
k: (
torch.cat([r[k] for r in rollouts], dim=0)
if k not in ["gt_answer", "test_cases"]
else [r[k] for r in rollouts]
)
for k in rollouts[0].keys()
}
if self.producer_idx == 0 and not self.eval_mode:
# naive rollout strategy without load balancing
rollouts = await self.generate(input_ids, attention_mask, **kwargs)
if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
# for agentic producer, AsyncSimpleProducer is not the main producer, so we don't log rollouts
if (
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
or self.latest_rollout_log_step == -1
@@ -721,11 +882,6 @@ class AsyncServer:
self.latest_rollout_log_step = self.consumer_global_step
return rollouts
def __del__(self):
if self.producer_idx == 0:
self.wandb_run.finish()
if hasattr(self, "rollout_log_file"):
self.rollout_log_file.close()
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
async def generate(self, input_ids, attention_mask, **kwargs):
rollouts = await super().generate(input_ids, attention_mask, **kwargs)
return rollouts

View File

@@ -110,7 +110,11 @@ if __name__ == "__main__":
# Sampling parameters
parser.add_argument(
"-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm", "async-vllm"]
"-b",
"--backend",
type=str,
default="transformers",
choices=["transformers", "vllm", "async-vllm", "async-agentic-vllm"],
)
parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
parser.add_argument(
@@ -215,7 +219,7 @@ if __name__ == "__main__":
namespace="ray-example",
runtime_env={
"env_vars": {
"RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false",
},
},
@@ -228,7 +232,7 @@ if __name__ == "__main__":
_temp_dir=args.ray_dir,
runtime_env={
"env_vars": {
"RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
# "RAY_DEBUG_POST_MORTEM": "1", # enable post-mortem debugging with ray
"TOKENIZERS_PARALLELISM": "false",
},
},
@@ -237,7 +241,7 @@ if __name__ == "__main__":
if args.top_k is None:
if args.backend == "transformers":
args.top_k = 50
elif args.backend == "vllm" or args.backend == "async-vllm":
elif "vllm" in args.backend:
args.top_k = -1
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
@@ -265,7 +269,7 @@ if __name__ == "__main__":
)
)
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
elif args.backend == "vllm" or args.backend == "async-vllm":
elif args.backend == "vllm" or args.backend == "async-vllm" or args.backend == "async-agentic-vllm":
# os.environ["VLLM_DP_SIZE"] = str(args.producer_data_parallel_size)
inference_model_config.update(
dict(
@@ -358,6 +362,25 @@ if __name__ == "__main__":
# Default system prompt
args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
if "agentic" in args.backend:
assert "vllm" in args.backend, "Agentic backend only supports async-agentic-vllm backends."
generate_config["n"] = 1 # agentic producer use AsyncProducer which processes one request a time
generate_config["max_tokens"] = (
2048 # max new tokens for each agentic step, usually smaller than max_new_tokens as agentic model will generate multiple steps
)
agentic_config = {
"model": args.model,
"model_type": "transformers",
"generate_cfg": {
"max_input_tokens": args.max_new_tokens + args.max_prompt_tokens,
},
}
agentic_config["generate_cfg"].update(
{k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]}
)
else:
agentic_config = None
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size)
@@ -378,6 +401,7 @@ if __name__ == "__main__":
num_generations=args.num_generations,
train_model_config=train_model_config,
grpo_config=grpo_config,
agentic_config=agentic_config,
plugin_config={
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,