Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 05:01:44 +00:00)

Commit: support agentic with asyncllm
@@ -4,6 +4,7 @@
Dataloader for sft, dpo, ppo
"""

import copy
import os
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union

@@ -423,7 +424,9 @@ class RawConversationDataset(Dataset):
    Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
    def __init__(
        self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str, tokenize=True
    ) -> None:
        self.tokenizer = tokenizer
        self.raw_texts = []
        with jsonlines.open(input_file) as f:

@@ -432,30 +435,50 @@ class RawConversationDataset(Dataset):
        self.tokenized_texts = [None] * len(self.raw_texts)
        self.max_length = max_length
        self.system_prompt = system_prompt
        self.tokenize = tokenize

    def __len__(self) -> int:
        return len(self.raw_texts)

    def __getitem__(self, index: int):
        if self.tokenized_texts[index] is None:
            message = self.raw_texts[index]
            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
            self.tokenized_texts[index] = dict(tokens)
        return self.tokenized_texts[index]
        if self.tokenize:
            if self.tokenized_texts[index] is None:
                message = self.raw_texts[index]
                tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
                self.tokenized_texts[index] = dict(tokens)
            return self.tokenized_texts[index]
        else:
            chat = copy.deepcopy(self.raw_texts[index])
            chat["messages"] = [{"role": "system", "content": self.system_prompt}] + chat["messages"]
            return chat


def collate_fn_grpo(batch):
    input_ids = [item["input_ids"] for item in batch]
    attention_mask = [item["attention_mask"] for item in batch]
    labels = [item["labels"] for item in batch]
    # Assume input_ids, attention_mask, labels are already of the same length,
    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    input_ids = torch.stack(input_ids)
    attention_mask = torch.stack(attention_mask)
    labels = torch.stack(labels)
    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
    if "test_cases" in batch[0]:
        ret["test_cases"] = [item["test_cases"] for item in batch]
    if "gt_answer" in batch[0]:
        ret["gt_answer"] = [item["gt_answer"] for item in batch]
    return ret
    if "input_ids" in batch[0]:
        # tokenized format
        input_ids = [item["input_ids"] for item in batch]
        attention_mask = [item["attention_mask"] for item in batch]
        labels = [item["labels"] for item in batch]
        # Assume input_ids, attention_mask, labels are already of the same length,
        # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
        input_ids = torch.stack(input_ids)
        attention_mask = torch.stack(attention_mask)
        labels = torch.stack(labels)
        ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
        if "test_cases" in batch[0]:
            ret["test_cases"] = [item["test_cases"] for item in batch]
        if "gt_answer" in batch[0]:
            ret["gt_answer"] = [item["gt_answer"] for item in batch]
        return ret
    elif "messages" in batch[0]:
        # vllm format
        ret = {
            "messages": [item["messages"] for item in batch],
        }
        if "test_cases" in batch[0]:
            ret["test_cases"] = [item["test_cases"] for item in batch]
        if "gt_answer" in batch[0]:
            ret["gt_answer"] = [item["gt_answer"] for item in batch]
        return ret
    else:
        raise ValueError("Unsupported batch format")
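For reference, here is a minimal, self-contained sketch (not the repository's code) of the two batch layouts the updated `collate_fn_grpo` now distinguishes: pre-tokenized samples carrying `input_ids`/`attention_mask`/`labels` tensors, and raw `messages` samples produced when `tokenize=False` is passed to `RawConversationDataset`. The helper and sample data below are hypothetical.

```python
# Illustrative sketch of the two batch formats; the real collate_fn_grpo lives in the ColossalChat loader.
import torch


def toy_collate(batch):
    if "input_ids" in batch[0]:  # tokenized format: stack equal-length tensors
        out = {k: torch.stack([item[k] for item in batch]) for k in ("input_ids", "attention_mask", "labels")}
    elif "messages" in batch[0]:  # raw-chat (vllm/agentic) format: keep message lists as-is
        out = {"messages": [item["messages"] for item in batch]}
    else:
        raise ValueError("Unsupported batch format")
    for extra in ("test_cases", "gt_answer"):  # optional reward fields are passed through untouched
        if extra in batch[0]:
            out[extra] = [item[extra] for item in batch]
    return out


tokenized_batch = [
    {"input_ids": torch.ones(8, dtype=torch.long), "attention_mask": torch.ones(8, dtype=torch.long),
     "labels": torch.ones(8, dtype=torch.long), "gt_answer": "42"}
    for _ in range(2)
]
chat_batch = [{"messages": [{"role": "user", "content": "1+1?"}], "gt_answer": "2"}]

print(toy_collate(tokenized_batch)["input_ids"].shape)  # torch.Size([2, 8])
print(toy_collate(chat_batch)["messages"])
```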
applications/ColossalChat/coati/distributed/agent/agentic.py — 199 lines (new file)
@@ -0,0 +1,199 @@
import copy
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.distributed.agent.agentic_math_utils import TIR_SYSTEM, CustomTransformers
|
||||
from coati.distributed.producer import BaseProducer
|
||||
from qwen_agent.agents import TIRMathAgent
|
||||
from vllm import SamplingParams
|
||||
|
||||
|
||||
@ray.remote
|
||||
class AgenticProducer(BaseProducer):
|
||||
"""
|
||||
Asyncronous version of the producer that uses vLLM for generation.
|
||||
This class is designed to generate agentic response
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
producer_idx,
|
||||
num_producers,
|
||||
num_consumer_procs,
|
||||
num_episodes,
|
||||
batch_size,
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
async_producers,
|
||||
tokenizer_config=None,
|
||||
agentic_config=None,
|
||||
microbatch_size=1,
|
||||
backend="transformers",
|
||||
num_generations: int = 8,
|
||||
consumer_plugin_config=None,
|
||||
eval_dataset_config=None,
|
||||
eval_interval=-1, # disable evaluation
|
||||
grpo_config: Dict[str, Any] = None,
|
||||
eval_save_dir: str = "./eval",
|
||||
eval_generation_config={},
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_log_file: str = "./rollout_log.jsonl",
|
||||
enable_profiling: bool = False,
|
||||
n_behind: int = 0,
|
||||
):
|
||||
assert microbatch_size == 1 # microbatch_size must be 1 for agentic producer
|
||||
assert batch_size == 1 # batch_size must be 1 for agentic producer
|
||||
super().__init__(
|
||||
producer_idx,
|
||||
num_producers,
|
||||
num_consumer_procs,
|
||||
num_episodes,
|
||||
batch_size,
|
||||
train_dataset_config,
|
||||
model_config,
|
||||
generate_config,
|
||||
tokenizer_config,
|
||||
microbatch_size,
|
||||
backend,
|
||||
consumer_plugin_config,
|
||||
eval_dataset_config=eval_dataset_config,
|
||||
eval_interval=eval_interval,
|
||||
grpo_config=grpo_config,
|
||||
eval_save_dir=eval_save_dir,
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
log_rollout_interval=log_rollout_interval,
|
||||
rollout_log_file=rollout_log_file,
|
||||
enable_profiling=enable_profiling,
|
||||
n_behind=n_behind,
|
||||
enable_agentic=True,
|
||||
)
|
||||
self.eval_generation_config = copy.deepcopy(generate_config)
|
||||
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||
self.eval_generation_config.update(eval_generation_config)
|
||||
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
|
||||
self.async_producers = async_producers
|
||||
self.num_generations = num_generations
|
||||
self.generate_config = generate_config
|
||||
self.agentic_config = model_config if not agentic_config else agentic_config
|
||||
self.agentic_config.update({"model": model_config["path"]})
|
||||
self.llm = CustomTransformers(self.agentic_config, self.producer_idx, generation_workers=self.async_producers)
|
||||
self.bot = TIRMathAgent(llm=self.llm, name=model_config["path"], system_message=TIR_SYSTEM)
|
||||
|
||||
def rollout(self, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Rollout function to generate responses for the input, for example, using LLM or agentic pipeline.
|
||||
This function should be implemented in subclasses.
|
||||
"""
|
||||
assert len(kwargs["messages"]) == 1, "Only support batch size of 1 for agentic producer"
|
||||
messages = kwargs["messages"][0]
|
||||
prompt_input_ids = self.tokenizer.apply_chat_template(
|
||||
messages, return_tensors="pt", tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
# add left padding
|
||||
prompt_length = prompt_input_ids.shape[1]
|
||||
max_prompt_length = self.train_dataset_config["max_length"]
|
||||
to_pad_left = max_prompt_length - prompt_length
|
||||
rollouts = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"action_mask": [],
|
||||
"action_log_probs": [],
|
||||
"response_idx": [],
|
||||
}
|
||||
for i in range(self.num_generations):
|
||||
_messages = copy.deepcopy(messages)
|
||||
# exhaust the streaming generator; `response` holds the agent's final turn(s) afterwards
for response in self.bot.run(_messages):
    continue
_messages.extend(response)
|
||||
response_input_ids = self.tokenizer.apply_chat_template(_messages, return_tensors="pt", tokenize=True)
|
||||
# truncate if too long
|
||||
response_input_ids = response_input_ids[:, : self.grpo_config["max_length"] - to_pad_left]
|
||||
# add right padding (left padding was computed from the prompt) to reach max_length
|
||||
to_pad_right = self.grpo_config["max_length"] - response_input_ids.shape[1] - to_pad_left
|
||||
response_length = response_input_ids.shape[1] - prompt_length
|
||||
input_ids = torch.nn.functional.pad(
|
||||
response_input_ids, (to_pad_left, to_pad_right), "constant", value=self.tokenizer.pad_token_id
|
||||
) # [1, max_length]
|
||||
attention_mask = torch.nn.functional.pad(
|
||||
torch.ones_like(response_input_ids), (to_pad_left, to_pad_right), "constant", value=0
|
||||
) # [1, max_length]
|
||||
action_mask = torch.nn.functional.pad(
|
||||
torch.ones(size=(1, response_length)), (0, to_pad_right), "constant", value=0
|
||||
) # [1, max_length-prompt_length]
|
||||
rollouts["attention_mask"].append(attention_mask)
|
||||
rollouts["action_mask"].append(action_mask)
|
||||
rollouts["action_log_probs"].append(
|
||||
torch.ones(size=(1, self.grpo_config["max_length"] - max_prompt_length))
|
||||
) # dummy log probs
|
||||
rollouts["response_idx"].append(
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
self.train_dataset_config["max_length"],
|
||||
self.train_dataset_config["max_length"] + response_length,
|
||||
]
|
||||
]
|
||||
)
|
||||
) # [1, 2]
|
||||
rollouts["input_ids"].append(input_ids)
|
||||
# breakpoint()
|
||||
rollouts = {k: torch.cat(v, dim=0).unsqueeze(0) for k, v in rollouts.items()} # [num_generations, ...]
|
||||
rollouts["temperature"] = torch.tensor([self.agentic_config.get("temperature", 1.0)])
|
||||
if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
|
||||
# when agentic mode is on, the AsyncSimpleProducer workers are not the main producers, so rollout logging happens here instead
|
||||
if (
|
||||
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
|
||||
or self.latest_rollout_log_step == -1
|
||||
):
|
||||
new_record = (
|
||||
json.dumps(
|
||||
{
|
||||
"train_step": self.consumer_global_step,
|
||||
"rollout": self.tokenizer.batch_decode(
|
||||
rollouts["input_ids"][:, 0], skip_special_tokens=True
|
||||
),
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
self.rollout_log_file.write(new_record)
|
||||
self.rollout_log_file.flush()
|
||||
self.latest_rollout_log_step = self.consumer_global_step
|
||||
|
||||
if "gt_answer" in kwargs:
|
||||
rollouts["gt_answer"] = kwargs["gt_answer"]
|
||||
if "test_cases" in kwargs:
|
||||
rollouts["test_cases"] = kwargs["test_cases"]
|
||||
return rollouts
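The left/right padding arithmetic above is easier to follow in isolation. The sketch below reproduces the idea with made-up sizes (not the repository's tensors): prompts are left-padded to a fixed `max_prompt_length`, responses are right-padded, and every rollout ends up exactly `max_length` tokens long.

```python
# Standalone illustration of the padding layout used in the agentic rollout (example sizes only).
import torch
import torch.nn.functional as F

pad_token_id = 0
max_prompt_length = 8   # prompts are left-padded to this length
max_length = 16         # full rollout (prompt + response) length

prompt = torch.tensor([[5, 6, 7]])                     # [1, prompt_length]
response = torch.tensor([[5, 6, 7, 11, 12, 13, 14]])   # prompt + generated tokens

prompt_length = prompt.shape[1]
to_pad_left = max_prompt_length - prompt_length
to_pad_right = max_length - response.shape[1] - to_pad_left
response_length = response.shape[1] - prompt_length

input_ids = F.pad(response, (to_pad_left, to_pad_right), value=pad_token_id)             # [1, max_length]
attention_mask = F.pad(torch.ones_like(response), (to_pad_left, to_pad_right), value=0)  # [1, max_length]
action_mask = F.pad(torch.ones(1, response_length), (0, to_pad_right), value=0)          # [1, max_length - max_prompt_length]

print(input_ids.shape, attention_mask.shape, action_mask.shape)
# torch.Size([1, 16]) torch.Size([1, 16]) torch.Size([1, 8])
```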
|
||||
|
||||
def sync_model(self, episode, step) -> None:
|
||||
"""
|
||||
sync model from consumer to self.async_producers
|
||||
AgenticProducer does not hold any model weights, so no need to sync model to self.async_producers
|
||||
"""
|
||||
tasks = []
|
||||
for proc in self.async_producers:
|
||||
tasks.append(proc.async_sync_model.remote(episode, step, self.num_producers))
|
||||
ray.get(tasks)
|
||||
return
|
||||
|
||||
def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
sync data from self to consumer
|
||||
"""
|
||||
tasks = []
|
||||
for idx, proc in enumerate(self.async_producers):
|
||||
if idx == self.producer_idx % len(self.async_producers):
|
||||
tasks.append(proc.async_sync_data.remote(data, self.num_producers))
|
||||
else:
|
||||
tasks.append(proc.async_sync_data.remote({}, self.num_producers))
|
||||
ray.get(tasks)
|
||||
return
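As a plain-Python illustration (no Ray involved, all numbers hypothetical), this is the routing rule `sync_data` follows: every agentic producer pushes to every inference engine so the engines can count arrivals, but only the engine at `producer_idx % num_engines` receives that producer's real payload.

```python
# Toy model of the sync_data fan-out: real data goes to exactly one engine,
# empty dicts keep the per-engine arrival counters in step.
num_engines = 2
num_agentic_producers = 4

inbox = {engine: [] for engine in range(num_engines)}

for producer_idx in range(num_agentic_producers):
    payload = {"rollout_from": producer_idx}
    for engine in range(num_engines):
        if engine == producer_idx % num_engines:
            inbox[engine].append(payload)   # real rollout data
        else:
            inbox[engine].append({})        # placeholder so every engine sees one call per producer

for engine, msgs in inbox.items():
    real = [m for m in msgs if m]
    print(f"engine {engine}: {len(msgs)} calls, {len(real)} real payloads -> {real}")
```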
|
@@ -0,0 +1,170 @@
|
||||
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A TIR(tool-integrated reasoning) math agent
|
||||
```bash
|
||||
python tir_math.py
|
||||
```
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
|
||||
import ray
|
||||
from qwen_agent.agents import TIRMathAgent
|
||||
from qwen_agent.llm.base import register_llm
|
||||
from qwen_agent.llm.function_calling import BaseFnCallModel
|
||||
from qwen_agent.llm.transformers_llm import Transformers
|
||||
from qwen_agent.log import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")
|
||||
|
||||
# We use the following two systems to distinguish between COT mode and TIR mode
|
||||
TIR_SYSTEM = """Please integrate natural language reasoning with programs to solve the problem above, and put your final answer within \\boxed{}."""
|
||||
COT_SYSTEM = """Please reason step by step, and put your final answer within \\boxed{}."""
|
||||
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("/mnt/nfs/share/data/model/Qwen2.5-Math-7B-Instruct", trust_remote_code=True)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __init__(self, stop_token_ids):
|
||||
self.stop_token_ids = stop_token_ids
|
||||
|
||||
def __call__(self, input_ids, scores, **kwargs):
|
||||
# Check if the last token is one of the stop tokens
|
||||
if input_ids[0, -1].item() in self.stop_token_ids:
|
||||
return True
|
||||
return False
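A quick, self-contained check of how such a stopping criterion behaves (the token ids below are arbitrary, and the class is redefined here so the snippet runs on its own); with Hugging Face `generate` it would be passed in a `StoppingCriteriaList`.

```python
# Illustrative use of a stop-on-token criterion with made-up token ids.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        # stop as soon as the last generated token is one of the stop tokens
        return input_ids[0, -1].item() in self.stop_token_ids


stop = StopOnTokens(stop_token_ids=[2])
print(stop(torch.tensor([[5, 9, 2]]), None))   # True  -> generation would halt here
print(stop(torch.tensor([[5, 9, 7]]), None))   # False
# With Hugging Face generate: model.generate(..., stopping_criteria=StoppingCriteriaList([stop]))
```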
|
||||
|
||||
|
||||
class LocalLLMFromGenerationWorkers:
|
||||
"""
|
||||
A class that wraps the Transformers model to support API-based text generation.
|
||||
"""
|
||||
|
||||
def __init__(self, generation_worker=None):
|
||||
self.device = "cpu"
|
||||
self.generation_worker = generation_worker
|
||||
|
||||
def generate(self, **kwargs):
|
||||
rollouts = ray.get(self.generation_worker.generate.remote(**kwargs))
|
||||
return rollouts["input_ids"]
|
||||
|
||||
|
||||
@register_llm("api_based_transformers")
|
||||
class CustomTransformers(Transformers):
|
||||
"""
|
||||
Transformers class that supports API-based text generation.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: dict, producer_idx, generation_workers=None):
|
||||
BaseFnCallModel.__init__(self, cfg) # skip the super() init of Transformers to avoid loading hf model
|
||||
############ Setup logic from Transformers.__init__ ###############
|
||||
if "model" not in cfg:
|
||||
raise ValueError("Please provide the model id or directory through `model` in cfg.")
|
||||
|
||||
try:
|
||||
from transformers import AutoConfig, AutoProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Could not import classes from transformers. " "Please install it with `pip install -U transformers`"
|
||||
) from e
|
||||
|
||||
self.hf_config = AutoConfig.from_pretrained(cfg["model"])
|
||||
arch = self.hf_config.architectures[0]
|
||||
if len(self.hf_config.architectures) > 1:
|
||||
logger.warning(
|
||||
f"The config for the transformers model type contains more than one architecture, choosing the first: {arch}"
|
||||
)
|
||||
|
||||
# try loading a processor; if we get a tokenizer back, treat the model as text-only
|
||||
processor = AutoProcessor.from_pretrained(cfg["model"])
|
||||
if isinstance(processor, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
|
||||
logger.info(f"Regarding the transformers model as text-only since its processor is a tokenizer.")
|
||||
self.tokenizer = processor
|
||||
self._support_multimodal_input = False
|
||||
else:
|
||||
self.processor = processor
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
self._support_multimodal_input = True
|
||||
################################################################
|
||||
self.generation_workers = generation_workers
|
||||
self.hf_models = [
|
||||
LocalLLMFromGenerationWorkers(generation_worker=generation_worker)
|
||||
for generation_worker in generation_workers
|
||||
]
|
||||
self.producer_idx = producer_idx
|
||||
self.load_balancer_idx = producer_idx % len(self.generation_workers)
|
||||
|
||||
@property
|
||||
def hf_model(self):
|
||||
# return the generation worker currently selected by the load balancer
|
||||
model = self.hf_models[self.load_balancer_idx]
|
||||
return model
|
||||
|
||||
def _chat_stream(
|
||||
self,
|
||||
messages,
|
||||
delta_stream: bool,
|
||||
generate_cfg: dict,
|
||||
):
|
||||
# overwrite streaming because streamer is not serializable
|
||||
# determine load balancer idx based on producer load, refresh every generation
|
||||
load = [ray.get(generation_worker.get_producer_load.remote()) for generation_worker in self.generation_workers]
|
||||
min_load = min(load)
|
||||
candidates = [i for i, l in enumerate(load) if l == min_load]
|
||||
# random tie break
|
||||
self.load_balancer_idx = random.choice(candidates)
|
||||
response = self._chat_no_stream(messages=messages, generate_cfg=generate_cfg)
|
||||
# if self.producer_idx == 0:
|
||||
# print(response)
|
||||
yield response
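The scheduling rule in `_chat_stream` (pick the least-loaded generation worker, breaking ties at random) is easy to isolate; here is a small sketch with fake load numbers.

```python
# Least-load worker selection with a random tie break, as used to pick a generation worker.
import random

loads = [3, 1, 1, 4]                      # pretend queue lengths reported by each worker
min_load = min(loads)
candidates = [i for i, load in enumerate(loads) if load == min_load]
chosen = random.choice(candidates)        # worker 1 or 2, chosen uniformly
print(f"workers with minimal load {min_load}: {candidates}, dispatching to worker {chosen}")
```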
|
||||
|
||||
|
||||
def init_agent_service():
|
||||
llm_cfg = {
|
||||
# Local Qwen2.5-Math checkpoint served through the Transformers backend:
"model": "/mnt/nfs/share/data/model/Qwen2.5-Math-7B-Instruct",
"model_type": "transformers",
"generate_cfg": {
    # greedy decoding
    "top_k": 1,
},
|
||||
}
|
||||
llm = CustomTransformers(llm_cfg)
|
||||
bot = TIRMathAgent(llm=llm, name="Qwen2.5-Math", system_message=TIR_SYSTEM)
|
||||
return bot
|
||||
|
||||
|
||||
def app_tui():
|
||||
# Define the agent
|
||||
bot = init_agent_service()
|
||||
|
||||
# Chat
|
||||
messages = []
|
||||
while True:
|
||||
# Query example: 斐波那契数列前10个数字
|
||||
query = input("user question: ")
|
||||
messages.append({"role": "user", "content": query})
|
||||
response = []
|
||||
for response in bot.run(messages):
|
||||
print("bot response:", response)
|
||||
messages.extend(response)
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# # Test the TIR math agent locally
|
||||
# app_tui()
|
@@ -1,149 +0,0 @@
|
||||
"""
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 LangChain
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class LangChainChatModel(BaseChatModel):
|
||||
"""A custom chat model that echoes the first `parrot_buffer_length` characters
|
||||
of the input.
|
||||
|
||||
When contributing an implementation to LangChain, carefully document
|
||||
the model including the initialization parameters, include
|
||||
an example of how to initialize the model and include any relevant
|
||||
links to the underlying models documentation or API.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = LangChainChatModel(parrot_buffer_length=2, model="bird-brain-001")
|
||||
result = model.invoke([HumanMessage(content="hello")])
|
||||
result = model.batch([[HumanMessage(content="hello")],
|
||||
[HumanMessage(content="world")]])
|
||||
"""
|
||||
|
||||
model_name: str = Field(alias="model")
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
timeout: Optional[int] = None
|
||||
stop: Optional[List[str]] = None
|
||||
async_server_manager: Optional[Any] = None
|
||||
max_retries: int = 2
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Override the _generate method to implement the chat model logic.
|
||||
|
||||
This can be a call to an API, a call to a local model, or any other
|
||||
implementation that generates a response to the input prompt.
|
||||
|
||||
Args:
|
||||
messages: the prompt composed of a list of messages.
|
||||
stop: a list of strings on which the model should stop generating.
|
||||
If generation stops due to a stop token, the stop token itself
|
||||
SHOULD BE INCLUDED as part of the output. This is not enforced
|
||||
across models right now, but it's a good practice to follow since
|
||||
it makes it much easier to parse the output of the model
|
||||
downstream and understand why generation stopped.
|
||||
run_manager: A run manager with callbacks for the LLM.
|
||||
"""
|
||||
self.async_server_manager.generate(messages, stop, run_manager, **kwargs)
|
||||
tokens = last_message.content[: self.parrot_buffer_length]
|
||||
ct_input_tokens = sum(len(message.content) for message in messages)
|
||||
ct_output_tokens = len(tokens)
|
||||
message = AIMessage(
|
||||
content=tokens,
|
||||
additional_kwargs={}, # Used to add additional payload to the message
|
||||
response_metadata={ # Use for response metadata
|
||||
"time_in_seconds": 3,
|
||||
"model_name": self.model_name,
|
||||
},
|
||||
usage_metadata={
|
||||
"input_tokens": ct_input_tokens,
|
||||
"output_tokens": ct_output_tokens,
|
||||
"total_tokens": ct_input_tokens + ct_output_tokens,
|
||||
},
|
||||
)
|
||||
##
|
||||
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""Stream the output of the model.
|
||||
|
||||
This method should be implemented if the model can generate output
|
||||
in a streaming fashion. If the model does not support streaming,
|
||||
do not implement it. In that case streaming requests will be automatically
|
||||
handled by the _generate method.
|
||||
|
||||
Args:
|
||||
messages: the prompt composed of a list of messages.
|
||||
stop: a list of strings on which the model should stop generating.
|
||||
If generation stops due to a stop token, the stop token itself
|
||||
SHOULD BE INCLUDED as part of the output. This is not enforced
|
||||
across models right now, but it's a good practice to follow since
|
||||
it makes it much easier to parse the output of the model
|
||||
downstream and understand why generation stopped.
|
||||
run_manager: A run manager with callbacks for the LLM.
|
||||
"""
|
||||
raise NotImplementedError("Streaming is not implemented for this model. Please implement the _stream method.")
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Get the type of language model used by this chat model."""
|
||||
return "echoing-chat-model-advanced"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Return a dictionary of identifying parameters.
|
||||
|
||||
This information is used by the LangChain callback system, which
|
||||
is used for tracing purposes, making it possible to monitor LLMs.
|
||||
"""
|
||||
return {
|
||||
# The model name allows users to specify custom token counting
|
||||
# rules in LLM monitoring applications (e.g., in LangSmith users
|
||||
# can provide per token pricing for their model and monitor
|
||||
# costs for the given LLM.)
|
||||
"model_name": self.model_name,
|
||||
}
|
@@ -0,0 +1,126 @@
|
||||
# -------------------------------
|
||||
# 1. Define the Python tool
|
||||
# -------------------------------
|
||||
import io
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
import requests
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
|
||||
class Capturing(list):
|
||||
"""Capture stdout prints inside exec()"""
|
||||
|
||||
def __enter__(self):
|
||||
self._stdout = sys.stdout
|
||||
sys.stdout = self._stringio = io.StringIO()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.extend(self._stringio.getvalue().splitlines())
|
||||
sys.stdout = self._stdout
|
||||
|
||||
|
||||
@tool
|
||||
def python(code: str) -> str:
|
||||
"""
|
||||
This function executes a string of Python code and returns the printed output.
|
||||
You need to print the output. Please import all libraries used in the code string.
|
||||
"""
|
||||
local_vars = {}
|
||||
with Capturing() as output:
|
||||
exec(code, {}, local_vars)
|
||||
if output == []:
|
||||
return "Error: No output printed from the code. Please ensure you print the output."
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# 2. Define a Custom API LLM wrapper
|
||||
# -------------------------------
|
||||
class CustomAPILLM:
|
||||
def __init__(self, api_url: str, api_key: str = None):
|
||||
self.api_url = api_url
|
||||
self.api_key = api_key
|
||||
|
||||
def invoke(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""
|
||||
messages: list of {"role": "user"/"assistant"/"system", "content": "..."}
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
payload = {
|
||||
"model": "custom-model", # depends on your API
|
||||
"messages": messages,
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Adjust according to your API response format
|
||||
return data["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# 3. Build a ReAct Agent with LangGraph
|
||||
# -------------------------------
|
||||
def build_agent():
|
||||
# Wrap custom API LLM in LangChain-compatible interface
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
class LangChainCustomLLM(BaseChatModel):
|
||||
client: CustomAPILLM = None
|
||||
|
||||
def __init__(self, client: CustomAPILLM):
|
||||
super().__init__()
|
||||
self.client = client
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
    from langchain_core.outputs import ChatGeneration, ChatResult

    content = self.client.invoke([m.dict() for m in messages])
    return ChatResult(generations=[ChatGeneration(message=AIMessage(content=content))])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "custom-api-llm"
|
||||
|
||||
# Init LLM
|
||||
llm_client = CustomAPILLM(api_url="http://localhost:8000/v1/chat/completions")
|
||||
llm = LangChainCustomLLM(llm_client)
|
||||
|
||||
# Tools
|
||||
tools = [python]
|
||||
|
||||
# Memory (optional)
|
||||
memory = MemorySaver()
|
||||
|
||||
# Build ReAct agent
|
||||
agent = create_react_agent(llm, tools, checkpointer=memory)
|
||||
return agent
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# 4. Run the agent on a math problem
|
||||
# -------------------------------
|
||||
if __name__ == "__main__":
|
||||
agent = build_agent()
|
||||
|
||||
# Example math question
|
||||
user_input = "What is the least common multiple of 18 and 24? Use Python if needed."
|
||||
|
||||
config = {"configurable": {"thread_id": "math-1"}}
|
||||
for event in agent.stream({"messages": [("user", user_input)]}, config):
|
||||
if "agent" in event:
|
||||
print("Agent event:", event["agent"]["messages"][-1].content)
|
||||
elif "tools" in event:
|
||||
print("Tool event:", event["tools"]["messages"][-1].content)
|
||||
|
||||
final_state = agent.get_state(config)
print("Final Answer:", final_state.values["messages"][-1].content)
|
@@ -1,112 +0,0 @@
|
||||
"""
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 LangChain
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import contextlib
|
||||
import io
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
|
||||
def eval(code: str, _locals: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
# Store original keys before execution
|
||||
original_keys = set(_locals.keys())
|
||||
|
||||
try:
|
||||
with contextlib.redirect_stdout(io.StringIO()) as f:
|
||||
exec(code, builtins.__dict__, _locals)
|
||||
result = f.getvalue()
|
||||
if not result:
|
||||
result = "<code ran, no output printed to stdout>"
|
||||
except Exception as e:
|
||||
result = f"Error during execution: {repr(e)}"
|
||||
|
||||
# Determine new variables created during execution
|
||||
new_keys = set(_locals.keys()) - original_keys
|
||||
new_vars = {key: _locals[key] for key in new_keys}
|
||||
return result, new_vars
|
||||
|
||||
|
||||
def add(a: float, b: float) -> float:
|
||||
"""Add two numbers together."""
|
||||
return a + b
|
||||
|
||||
|
||||
def multiply(a: float, b: float) -> float:
|
||||
"""Multiply two numbers together."""
|
||||
return a * b
|
||||
|
||||
|
||||
def divide(a: float, b: float) -> float:
|
||||
"""Divide two numbers."""
|
||||
return a / b
|
||||
|
||||
|
||||
def subtract(a: float, b: float) -> float:
|
||||
"""Subtract two numbers."""
|
||||
return a - b
|
||||
|
||||
|
||||
def sin(a: float) -> float:
|
||||
"""Take the sine of a number."""
|
||||
return math.sin(a)
|
||||
|
||||
|
||||
def cos(a: float) -> float:
|
||||
"""Take the cosine of a number."""
|
||||
return math.cos(a)
|
||||
|
||||
|
||||
def radians(a: float) -> float:
|
||||
"""Convert degrees to radians."""
|
||||
return math.radians(a)
|
||||
|
||||
|
||||
def exponentiation(a: float, b: float) -> float:
|
||||
"""Raise one number to the power of another."""
|
||||
return a**b
|
||||
|
||||
|
||||
def sqrt(a: float) -> float:
|
||||
"""Take the square root of a number."""
|
||||
return math.sqrt(a)
|
||||
|
||||
|
||||
def ceil(a: float) -> float:
|
||||
"""Round a number up to the nearest integer."""
|
||||
return math.ceil(a)
|
||||
|
||||
|
||||
tools = [
|
||||
add,
|
||||
multiply,
|
||||
divide,
|
||||
subtract,
|
||||
sin,
|
||||
cos,
|
||||
radians,
|
||||
exponentiation,
|
||||
sqrt,
|
||||
ceil,
|
||||
]
|
@@ -150,6 +150,7 @@ class BaseConsumer:
|
||||
self.profiler.enter("sync_model")
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.state_dict()
|
||||
print(f"[C{self.rank}]: Sync model before training")
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
@@ -164,6 +165,7 @@ class BaseConsumer:
|
||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
del state_dict
|
||||
print(f"[C{self.rank}]: Sync model before training done")
|
||||
torch.cuda.empty_cache()
|
||||
self.profiler.exit("sync_model")
|
||||
|
||||
@@ -323,7 +325,7 @@ class BaseConsumer:
|
||||
) # for setting start index when resuming training
|
||||
if self.rank == 0:
|
||||
print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
|
||||
|
||||
# breakpoint()
|
||||
if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
|
||||
episode != 0 or step >= self.n_behind
|
||||
):
|
||||
|
@@ -64,7 +64,7 @@ class AsyncInferenceBackend(BaseInferenceBackend):
|
||||
- action_mask (torch.Tensor): shape [B, N]
|
||||
where N is the number of generated tokens. And all tensors should be on CUDA.
|
||||
"""
|
||||
raise NotImplementedError("AsyncInferenceBackend does not support generate method.")
|
||||
raise NotImplementedError("Generate method must be implemented in subclass.")
|
||||
|
||||
|
||||
class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
@@ -84,6 +84,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
microbatch_size: int = 1,
|
||||
profiler=None,
|
||||
):
|
||||
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
|
||||
model_config.update(self.FORCE_MODEL_CONFIG)
|
||||
@@ -93,6 +94,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = num_generations
|
||||
self.profiler = profiler
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
@@ -158,6 +160,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
microbatch_size: int = 1,
|
||||
profiler=None,
|
||||
):
|
||||
if sgl is None:
|
||||
raise ImportError("sglang is not installed")
|
||||
@@ -223,6 +226,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
microbatch_size: int = 1,
|
||||
profiler=None,
|
||||
):
|
||||
if LLM is None:
|
||||
raise ImportError("vllm is not installed")
|
||||
@@ -323,6 +327,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
microbatch_size: int = 1,
|
||||
profiler=None,
|
||||
):
|
||||
if LLM is None:
|
||||
raise ImportError("vllm is not installed")
|
||||
@@ -332,7 +337,8 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
generate_config = generate_config.copy()
|
||||
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
generate_config.update({"n": num_generations})
|
||||
if "n" not in generate_config:
|
||||
generate_config.update({"n": num_generations})
|
||||
self.generate_config = generate_config
|
||||
self.sample_params = SamplingParams(**generate_config)
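The intent of the change above is to default `n` to `num_generations` only when the caller has not set it. A minimal sketch of that logic, assuming vLLM is installed and using illustrative config values:

```python
# Sketch of the "respect a caller-provided n" logic when building vLLM SamplingParams.
from vllm import SamplingParams

num_generations = 8
generate_config = {"temperature": 0.8, "top_p": 0.95, "max_tokens": 512}

generate_config = generate_config.copy()
if "n" not in generate_config:          # only default n when the caller did not set it
    generate_config["n"] = num_generations

sample_params = SamplingParams(**generate_config)
print(sample_params.n, sample_params.max_tokens)   # 8 512
```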
|
||||
self.model_config = model_config
|
||||
@@ -340,6 +346,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
self.num_generations = num_generations
|
||||
self.queued_requests = []
|
||||
self.microbatch_size = microbatch_size
|
||||
self.profiler = profiler
|
||||
|
||||
@torch.no_grad()
|
||||
async def generate(
|
||||
@@ -351,6 +358,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
input_ids (torch.Tensor): shape [B, S], B=1
|
||||
attention_mask (torch.Tensor): shape [B, S]
|
||||
"""
|
||||
# breakpoint()
|
||||
assert input_ids.size(0) == attention_mask.size(0) == 1
|
||||
response_start_idx = input_ids.size(1)
|
||||
first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
|
||||
@@ -366,6 +374,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
self.queued_requests.append(request_id) # enqueue
|
||||
# pop the first input_ids and attention_mask
|
||||
prompt_token_ids = input_ids_no_padding[0]
|
||||
self.profiler.enter(f"vllm generate {request_id}")
|
||||
outputs = self.engine.generate(
|
||||
prompt={"prompt_token_ids": prompt_token_ids}, sampling_params=sample_params, request_id=request_id
|
||||
)
|
||||
@@ -380,6 +389,7 @@ class AsyncVLLMInferenceBackend(AsyncInferenceBackend):
|
||||
assert len(output_i.logprobs) == len(output_i.token_ids)
|
||||
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
|
||||
log_probs.append(p)
|
||||
self.profiler.exit(f"vllm generate {request_id}")
|
||||
# pad them
|
||||
max_len = self.sample_params.max_tokens
|
||||
action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
|
||||
|
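The prompt handed to the async engine is left-padded, so the backend first locates the real tokens and strips the padding. A tiny sketch of that step (made-up ids, pad id 0):

```python
# Finding the first real token in a left-padded prompt and stripping the padding,
# as done before submitting the prompt to the async vLLM engine.
import torch

pad_token_id = 0
input_ids = torch.tensor([[0, 0, 0, 11, 12, 13]])            # [1, S], left-padded

first_non_padding = (input_ids != pad_token_id).int().argmax(dim=1)   # tensor([3])
prompt_token_ids = input_ids[0, first_non_padding[0]:].tolist()

print(first_non_padding.item(), prompt_token_ids)             # 3 [11, 12, 13]
```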
@@ -4,13 +4,13 @@ import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import ray
|
||||
from coati.distributed.agent.agentic import AgenticProducer
|
||||
|
||||
from .consumer import SimpleConsumer
|
||||
from .grpo_consumer import GRPOConsumer
|
||||
from .producer import AsyncProducer, SimpleProducer
|
||||
from .producer import AsyncSimpleProducer, SimpleProducer
|
||||
|
||||
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
|
||||
Producer_MAP = {"Simple": SimpleProducer, "Async": AsyncProducer}
|
||||
|
||||
|
||||
def get_jsonl_size_fast(path: str) -> int:
|
||||
@@ -42,6 +42,7 @@ def launch_distributed(
|
||||
generate_config: Dict[str, Any],
|
||||
train_model_config: Dict[str, Any],
|
||||
grpo_config: Dict[str, Any],
|
||||
agentic_config: Optional[Dict[str, Any]],
|
||||
plugin_config: Dict[str, Any],
|
||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||
inference_backend: str = "transformers",
|
||||
@@ -73,7 +74,7 @@ def launch_distributed(
|
||||
num_samples = get_jsonl_size_fast(dataset_path)
|
||||
global_inference_batch_size = inference_batch_size * num_producers
|
||||
num_update_per_episode = num_samples // global_inference_batch_size
|
||||
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
||||
num_recv_per_update = inference_batch_size // inference_microbatch_size if "async" not in inference_backend else 1
|
||||
|
||||
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
|
||||
wandb_group_name = str(uuid.uuid4())
|
||||
@@ -105,9 +106,12 @@ def launch_distributed(
|
||||
|
||||
producer_procs = []
|
||||
if "async" in inference_backend:
|
||||
core_producer = AsyncProducer
|
||||
core_producer = AsyncSimpleProducer
|
||||
else:
|
||||
core_producer = Producer_MAP.get("Simple", SimpleProducer)
|
||||
core_producer = SimpleProducer
|
||||
enable_agentic = "agentic" in inference_backend
|
||||
if enable_agentic:
|
||||
inference_backend = inference_backend.replace("agentic-", "")
|
||||
for i in range(num_producers):
|
||||
node_id = gpu_to_node_id[0]
|
||||
producer_ip_address = gpu_to_ip_address[0]
|
||||
@@ -125,7 +129,11 @@ def launch_distributed(
|
||||
model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
tokenizer_config=tokenizer_config,
|
||||
microbatch_size=inference_microbatch_size,
|
||||
microbatch_size=(
|
||||
inference_microbatch_size * num_generations
|
||||
if "async" in inference_backend
|
||||
else inference_microbatch_size
|
||||
),
|
||||
backend=inference_backend,
|
||||
num_generations=num_generations,
|
||||
consumer_plugin_config=plugin_config,
|
||||
@@ -138,12 +146,63 @@ def launch_distributed(
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
log_rollout_interval=log_rollout_interval,
|
||||
rollout_log_file=rollout_log_file,
|
||||
rollout_log_file=rollout_log_file if not enable_agentic else None,
|
||||
enable_profiling=enable_profiling,
|
||||
n_behind=n_behind,
|
||||
)
|
||||
producer_procs.append(producer)
|
||||
ray.get([p.setup.remote() for p in producer_procs])
|
||||
"""
|
||||
# test async generate
|
||||
import torch
|
||||
import asyncio
|
||||
import time
|
||||
async def test():
|
||||
res_ref = producer_procs[0].generate.remote(torch.ones((2, 10), dtype=torch.int), torch.ones((2, 10), dtype=torch.int))
|
||||
res = await res_ref
|
||||
return res
|
||||
res = asyncio.run(test())
|
||||
print(res)
|
||||
time.sleep(1000)
|
||||
"""
|
||||
|
||||
if enable_agentic:
|
||||
# when agentic is enabled, we use core_producer as inference engine and
|
||||
# AgenticProducer as the real producer
|
||||
_producer_procs = producer_procs
|
||||
producer_procs = [
|
||||
AgenticProducer.options(num_cpus=1).remote(
|
||||
producer_idx=producer_idx,
|
||||
num_producers=num_producers * train_batch_size,
|
||||
num_consumer_procs=num_consumer_procs,
|
||||
num_episodes=num_episodes,
|
||||
batch_size=1, # batch_size must be 1 for agentic producer
|
||||
train_dataset_config=train_dataset_config,
|
||||
model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
async_producers=_producer_procs,
|
||||
tokenizer_config=tokenizer_config,
|
||||
agentic_config=agentic_config,
|
||||
microbatch_size=1, # microbatch_size must be 1 for agentic producer
|
||||
backend=inference_backend,
|
||||
num_generations=num_generations,
|
||||
consumer_plugin_config=plugin_config,
|
||||
eval_dataset_config=eval_dataset_config,
|
||||
eval_interval=eval_interval,
|
||||
grpo_config=grpo_config,
|
||||
eval_save_dir=eval_save_dir,
|
||||
eval_generation_config=eval_generation_config,
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
log_rollout_interval=log_rollout_interval,
|
||||
rollout_log_file=rollout_log_file,
|
||||
enable_profiling=enable_profiling,
|
||||
n_behind=n_behind,
|
||||
)
|
||||
for producer_idx in range(num_producers * inference_batch_size)
|
||||
]
|
||||
|
||||
generate_config_consumer = copy.deepcopy(generate_config)
|
||||
generate_config_consumer.update(
|
||||
dict(
|
||||
|
@@ -57,6 +57,7 @@ class BaseProducer:
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_log_file: str = "./rollout_log.jsonl",
|
||||
enable_profiling: bool = False,
|
||||
enable_agentic: bool = False,
|
||||
n_behind: int = 0,
|
||||
):
|
||||
self.producer_idx = producer_idx
|
||||
@@ -65,8 +66,12 @@ class BaseProducer:
|
||||
self.num_episodes = num_episodes
|
||||
self.batch_size = batch_size
|
||||
self.microbatch_size = microbatch_size
|
||||
assert batch_size % microbatch_size == 0
|
||||
self.num_microbatches = batch_size // microbatch_size
|
||||
if not isinstance(self, BaseAsyncProducer):
|
||||
assert batch_size % microbatch_size == 0, "batch_size must be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // microbatch_size
|
||||
else:
|
||||
assert microbatch_size > 0, "microbatch_size must be positive"
|
||||
self.num_microbatches = max(1, batch_size // microbatch_size)
|
||||
self.latest_eval_step = -1
|
||||
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)
|
||||
|
||||
@@ -84,13 +89,14 @@ class BaseProducer:
|
||||
self.latest_rollout_log_step = -1
|
||||
self.grpo_config = grpo_config
|
||||
self.n_behind = n_behind
|
||||
self.enable_agentic = enable_agentic
|
||||
reward_model_kwargs = {
|
||||
k: v
|
||||
for k, v in grpo_config.items()
|
||||
if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
|
||||
}
|
||||
self.response_format_tags = grpo_config.get("response_format_tags", None)
|
||||
if producer_idx == 0:
|
||||
if producer_idx == 0 and rollout_log_file is not None:
|
||||
if os.path.exists(rollout_log_file):
|
||||
raise ValueError(
|
||||
f"Rollout log file {rollout_log_file} already exists. Please delete it or change the name."
|
||||
@@ -121,7 +127,9 @@ class BaseProducer:
|
||||
|
||||
# init dataloader
|
||||
train_dataset_path = train_dataset_config.pop("path")
|
||||
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
|
||||
self.train_dataset = RawConversationDataset(
|
||||
self.tokenizer, train_dataset_path, **train_dataset_config, tokenize=not self.enable_agentic
|
||||
)
|
||||
self.train_dataloader = DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=microbatch_size,
|
||||
@@ -159,7 +167,10 @@ class BaseProducer:
|
||||
for eval_task_name in self.eval_dataset_config:
|
||||
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
|
||||
eval_dataset = RawConversationDataset(
|
||||
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
|
||||
self.tokenizer,
|
||||
eval_dataset_path,
|
||||
**eval_dataset_config[eval_task_name],
|
||||
tokenize=not self.enable_agentic,
|
||||
)
|
||||
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
|
||||
self.eval_dataloaders[eval_task_name] = DataLoader(
|
||||
@@ -207,18 +218,34 @@ class BaseProducer:
|
||||
else:
|
||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Generate responses by running inference on the input_ids and attention_mask.
|
||||
"""
|
||||
return self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
|
||||
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Rollout function to generate responses for the input, for example, using LLM or agentic pipeline.
|
||||
This function should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def loop(self) -> None:
|
||||
|
||||
def sync_model(self, episode, step) -> None:
|
||||
"""
|
||||
Default implementation to sync model from consumer to producer.
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
self.profiler.enter("sync_model")
|
||||
if self.consumer_pp_size > 1:
|
||||
for pp_idx in range(self.consumer_pp_size):
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(step + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
|
||||
)
|
||||
@@ -226,6 +253,7 @@ class BaseProducer:
|
||||
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
print(f"[P{self.producer_idx}] Sync model episode {episode} step {(step + 1) // self.num_microbatches - 1}")
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
@@ -233,10 +261,18 @@ class BaseProducer:
|
||||
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||
self.load_state_dict(state_dict)
|
||||
self.profiler.exit("sync_model")
|
||||
print(f"[P{self.producer_idx}] Sync initial model done.")
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Default implementation to sync data from producer to consumer.
|
||||
"""
|
||||
ray_broadcast_tensor_dict(data, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}")
|
||||
|
||||
def loop(self) -> None:
|
||||
# breakpoint()
|
||||
self.sync_model(0, 0)
|
||||
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
|
||||
num_valid_microbatches = num_update_per_episode * self.num_microbatches
|
||||
|
||||
@@ -308,14 +344,12 @@ class BaseProducer:
|
||||
self.eval_mode = False
|
||||
self.latest_eval_step = self.consumer_global_step
|
||||
self.profiler.enter("rollout")
|
||||
if isinstance(self.model, BACKEND_MAP["async-vllm"]):
|
||||
outputs = asyncio.run(self.rollout(**batch))
|
||||
else:
|
||||
outputs = self.rollout(**batch)
|
||||
outputs = self.rollout(**batch)
|
||||
self.profiler.exit("rollout")
|
||||
outputs["temperature"] = torch.tensor(
|
||||
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
|
||||
).to(outputs["input_ids"].device)
|
||||
if "temperature" not in outputs:
|
||||
outputs["temperature"] = torch.tensor(
|
||||
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
|
||||
).to(outputs["input_ids"].device)
|
||||
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
|
||||
self.profiler.enter("calculate_reward")
|
||||
if self.grpo_config["reward_fn_type"] == "code":
|
||||
@@ -360,52 +394,16 @@ class BaseProducer:
|
||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||
outputs = pre_send(outputs)
|
||||
self.profiler.enter("send_broadcast_data")
|
||||
ray_broadcast_tensor_dict(
|
||||
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
|
||||
)
|
||||
self.sync_data(outputs)
|
||||
self.profiler.exit("send_broadcast_data")
|
||||
if (
|
||||
(i + 1) % self.num_microbatches == 0
|
||||
and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
|
||||
and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
|
||||
):
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
"enable_sleep_mode", False
|
||||
):
|
||||
self.model.llm.sleep()  # evict KV cache to avoid OOM
|
||||
# don't sync model for last iteration
|
||||
torch.cuda.empty_cache()
|
||||
self.profiler.enter("sync_model")
|
||||
if self.consumer_pp_size > 1:
|
||||
for pp_idx in range(self.consumer_pp_size):
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
|
||||
)
|
||||
if "consumer_global_step" in state_dict:
|
||||
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
if "consumer_global_step" in state_dict:
|
||||
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||
self.load_state_dict(state_dict)
|
||||
self.profiler.exit("sync_model")
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
"enable_sleep_mode", False
|
||||
):
|
||||
self.model.llm.wake_up()
|
||||
self.sync_model(episode, i)
|
||||
# linear annealing for 1 episode, temperature from initial to 0.9
|
||||
if episode <= 0:
|
||||
if episode <= 0 and hasattr(self, "model"):
|
||||
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
||||
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
@@ -475,7 +473,7 @@ class SimpleProducer(BaseProducer):
|
||||
n_behind=n_behind,
|
||||
)
|
||||
self.model = self.backend_cls(
|
||||
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size
|
||||
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size, profiler=self.profiler
|
||||
)
|
||||
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||
@@ -484,7 +482,7 @@ class SimpleProducer(BaseProducer):
|
||||
|
||||
@torch.no_grad()
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
rollouts = self.generate(input_ids, attention_mask, **kwargs)
|
||||
if self.producer_idx == 0 and not self.eval_mode:
|
||||
if (
|
||||
self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
|
||||
@@ -516,8 +514,7 @@ class SimpleProducer(BaseProducer):
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
|
||||
@ray.remote
|
||||
class AsyncProducer(BaseProducer):
|
||||
class BaseAsyncProducer(BaseProducer):
|
||||
"""
|
||||
Asynchronous version of the producer that uses vLLM for generation.
|
||||
"""
|
||||
@@ -577,15 +574,39 @@ class AsyncProducer(BaseProducer):
|
||||
)
|
||||
assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}"
|
||||
self.model = self.backend_cls(
|
||||
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size
|
||||
model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size, profiler=self.profiler
|
||||
)
|
||||
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||
self.eval_generation_config.update(eval_generation_config)
|
||||
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
|
||||
self.ready_processes = 0
|
||||
self.condition = asyncio.Condition()
|
||||
self.data_ready_for_sending = []
|
||||
|
||||
# @torch.no_grad()
|
||||
# async def generate(self, input_ids, attention_mask, **kwargs):
|
||||
# tasks = []
|
||||
# print("input_ids:", input_ids)
|
||||
# for prompt_id in range(input_ids.size(0)):
|
||||
# new_kwargs = copy.deepcopy(kwargs)
|
||||
# if "gt_answer" in new_kwargs:
|
||||
# new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id]
|
||||
# if "test_cases" in new_kwargs:
|
||||
# new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id]
|
||||
# tasks.append(
|
||||
# self.model.generate(
|
||||
# input_ids[prompt_id].unsqueeze(0),
|
||||
# attention_mask[prompt_id].unsqueeze(0),
|
||||
# **new_kwargs,
|
||||
# )
|
||||
# )
|
||||
# rollouts = await asyncio.gather(*tasks)
|
||||
# return rollouts
|
||||
|
||||
    @torch.no_grad()
    async def rollout(self, input_ids, attention_mask, **kwargs):
    async def generate(self, input_ids, attention_mask, **kwargs):
        # naive rollout strategy
        tasks = []
        for prompt_id in range(input_ids.size(0)):
            new_kwargs = copy.deepcopy(kwargs)
@@ -600,37 +621,224 @@ class AsyncProducer(BaseProducer):
                    **new_kwargs,
                )
            )
        # print(f"Producer {self.producer_idx} running {len(tasks)} tasks")
        rollouts = await asyncio.gather(*tasks)
        rollouts = {
            k: (
                torch.cat([r[k] for r in rollouts], dim=0)
                if k not in ["gt_answer", "test_cases"]
                else [r[k] for r in rollouts]
            )
            ).cpu()  # CUDA tensor is not serializable by ray
            for k in rollouts[0].keys()
        }
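        # Layout note (assuming each per-prompt result carries a leading batch dimension of 1):
        # with B prompts and n generations per prompt, the concatenated tensors keep a leading
        # dimension of B (e.g. "input_ids" is [B, n, seq_len]), while "gt_answer"/"test_cases"
        # stay as plain lists with one entry per prompt.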
        if self.producer_idx == 0 and not self.eval_mode:
            if (
                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                or self.latest_rollout_log_step == -1
            ):
                new_record = (
                    json.dumps(
                        {
                            "train_step": self.consumer_global_step,
                            "rollout": self.tokenizer.batch_decode(
                                rollouts["input_ids"][:, 0], skip_special_tokens=True
                            ),
                        }
                    )
                    + "\n"
                )
                self.rollout_log_file.write(new_record)
                self.rollout_log_file.flush()
                self.latest_rollout_log_step = self.consumer_global_step
        return rollouts

    @torch.no_grad()
    async def rollout(self, input_ids, attention_mask, **kwargs):
        """
        Advanced distributed rollout strategy that dispatches the generation tasks to different DP ranks.
        Must be implemented in subclasses.
        """
        raise NotImplementedError("rollout must be implemented in subclasses")

    async def get_producer_load(self):
        """
        Get the current load (number of queued requests) of this producer.
        """
        return len(self.model.queued_requests)
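
    # A minimal sketch of one possible `rollout` override, assuming Ray actor handles to the other
    # producers in a hypothetical `self.producers` list (names illustrative, not defined here):
    #
    # async def rollout(self, input_ids, attention_mask, **kwargs):
    #     loads = await asyncio.gather(*[p.get_producer_load.remote() for p in self.producers])
    #     target = loads.index(min(loads))  # dispatch the prompt batch to the least-loaded DP rank
    #     return await self.producers[target].generate.remote(input_ids, attention_mask, **kwargs)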

    async def async_sync_model(self, episode, step, num_processes: int = 1) -> None:
        """
        Asynchronous version of the consumer-to-producer model sync;
        called by another producer, such as the agentic producer.
        """
        async with self.condition:
            self.ready_processes += 1
            # Wait until all processes are ready
            if self.ready_processes < num_processes:
                await self.condition.wait()

            # Only one process should reset `ready_processes` and perform the sync
            if self.ready_processes == num_processes:
                self.ready_processes = 0
                self.condition.notify_all()  # Notify all waiting processes
                self.sync_model(episode, step)
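
    # Illustrative usage sketch (actor handle names are hypothetical): every agentic producer awaits
    #
    #     await async_producer.async_sync_model.remote(episode, step, num_processes=num_agentic_producers)
    #
    # and all callers are released together once `num_processes` calls have arrived; only the last
    # arrival actually runs `sync_model`, the others are simply woken by `notify_all`.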

    async def async_sync_data(self, data: Dict[str, torch.Tensor], num_processes: int = 1) -> None:
        # merge data dict
        async with self.condition:
            self.ready_processes += 1
            if data:
                self.data_ready_for_sending.append(data)

            # Wait until all processes are ready
            if self.ready_processes < num_processes:
                await self.condition.wait()

            # Only one process should reset `ready_processes` and perform the sync
            if self.ready_processes == num_processes:  # wait for all producers to join
                self.ready_processes = 0
                self.condition.notify_all()
                # merge data for sending
                if len(self.data_ready_for_sending) >= 1:
                    batch_rollout_data = {}
                    for key in self.data_ready_for_sending[0]:
                        batch_rollout_data[key] = torch.cat([d[key] for d in self.data_ready_for_sending], dim=0).to(
                            self.device
                        )
                    self.sync_data(batch_rollout_data)
                    self.data_ready_for_sending = []  # reset
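
    # Shape sketch (illustrative numbers): if three callers each contribute a dict whose "input_ids"
    # is [2, num_generations, seq_len], the merged batch sent to the consumer has "input_ids" of
    # shape [6, num_generations, seq_len]; a caller passing an empty dict only counts toward the
    # barrier and contributes no rows.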

    async def loop(self) -> None:
        self.sync_model(0, 0)
        num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
        num_valid_microbatches = num_update_per_episode * self.num_microbatches

        print(
            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
        )
        for episode in range(self.num_episodes):
            self.train_dataloader.sampler.set_epoch(episode)
            for i, batch in enumerate(self.train_dataloader):
                if i >= num_valid_microbatches:
                    break
                if self.eval_interval > 0 and self.eval_dataset_config is not None:
                    if (
                        self.consumer_global_step - self.latest_eval_step >= self.eval_interval
                        and self.consumer_global_step > self.latest_eval_step
                    ) or self.latest_eval_step == -1:
                        to_log_msg = {}
                        self.eval_mode = True
                        for eval_task_name in self.eval_dataloaders:
                            if self.producer_idx == 0:
                                print(
                                    f"[P{self.producer_idx}] Evaluate model at training step {self.consumer_global_step} on task {eval_task_name}"
                                )
                            eval_results = []
                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
                            for eval_batch in tqdm.tqdm(
                                self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
                            ):
                                eval_outputs = await self.rollout(**eval_batch, sample_params=self.eval_sample_params)
                                eval_results = eval_results + [
                                    self.evaluation_function(
                                        eval_outputs["input_ids"][m][n],
                                        eval_outputs[
                                            (
                                                "test_cases"
                                                if self.grpo_config["reward_fn_type"] == "code"
                                                else "gt_answer"
                                            )
                                        ][m],
                                        eval_outputs["response_idx"][m][n],
                                        tokenizer=self.tokenizer,
                                        eval_mode=True,
                                        tags=self.response_format_tags,
                                    )
                                    for m in range(eval_outputs["input_ids"].size(0))
                                    for n in range(eval_outputs["input_ids"].size(1))
                                ]
                            eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results])
                            eval_statistics_tensor[1] += len(eval_results)
                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
                            to_log_msg[f"eval/{eval_task_name}"] = (
                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
                            )
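                            # Worked example (illustrative numbers): if two producers contribute 30/50 and
                            # 45/50 valid answers, the SUM allreduce yields totals (75, 100), so
                            # eval/<task_name> is logged as 0.75.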
                            if self.producer_idx == 0:
                                print(
                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
                                )
                            # save eval results
                            safe_append_to_jsonl_file(
                                os.path.join(
                                    self.eval_save_dir,
                                    f"{eval_task_name}_training_step_{self.consumer_global_step}.jsonl",
                                ),
                                eval_results,
                            )

                        if self.producer_idx == 0:
                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                        self.eval_mode = False
                        self.latest_eval_step = self.consumer_global_step
                self.profiler.enter("rollout")
                # breakpoint()
                outputs = await self.rollout(**batch)
                self.profiler.exit("rollout")
                outputs["temperature"] = torch.tensor(
                    [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                ).to(outputs["input_ids"].device)
                bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
                self.profiler.enter("calculate_reward")
                if self.grpo_config["reward_fn_type"] == "code":
                    test_cases = []
                    for prompt_id in range(bs):
                        test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
                    reward_model_output = self.reward_model(
                        outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
                        test_cases=test_cases,
                        response_idx=outputs["response_idx"].view((-1, 2)),
                    )
                else:
                    gt_answer = []
                    for prompt_id in range(bs):
                        gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
                    reward_model_output = self.reward_model(
                        outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
                        gt_answer=gt_answer,
                        response_idx=outputs["response_idx"].view((-1, 2)),
                    )
                outputs["reward"] = (
                    torch.tensor([value[0] for value in reward_model_output])
                    .to(outputs["input_ids"].device)
                    .view((bs, num_gen, 1))
                )
                outputs["format_acc"] = (
                    torch.tensor([value[1] for value in reward_model_output])
                    .to(outputs["input_ids"].device)
                    .view((bs, num_gen, 1))
                )
                outputs["ans_acc"] = (
                    torch.tensor([value[2] for value in reward_model_output])
                    .to(outputs["input_ids"].device)
                    .view((bs, num_gen, 1))
                )
                if "gt_answer" in outputs:
                    outputs.pop("gt_answer")
                if "test_cases" in outputs:
                    outputs.pop("test_cases")
                self.profiler.exit("calculate_reward")

                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                outputs = pre_send(outputs)
                self.profiler.enter("send_broadcast_data")
                self.sync_data(outputs)
                self.profiler.exit("send_broadcast_data")
                if (
                    (i + 1) % self.num_microbatches == 0
                    and (episode != self.num_episodes - 1 or i != num_valid_microbatches - 1)
                    and (episode != 0 or (i + 1) > self.n_behind * self.num_microbatches)
                ):
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                        "enable_sleep_mode", False
                    ):
                        self.model.llm.sleep()  # evict KV cache to avoid OOM
                    # don't sync model for the last iteration
                    self.sync_model(episode, i)
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                        "enable_sleep_mode", False
                    ):
                        self.model.llm.wake_up()
                # linear annealing for 1 episode, temperature from initial to 0.9
                if episode <= 0:
                    ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
                        "temperature"
                    ] + ratio * 0.9
                    if isinstance(self.model, BACKEND_MAP["vllm"]):
                        self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
                            "temperature"
                        ] + ratio * 0.9
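                    # Worked example (illustrative numbers): with an initial temperature of 1.0 and a
                    # 100-step dataloader, step i = 25 gives ratio = 0.25, i.e. temperature
                    # 0.75 * 1.0 + 0.25 * 0.9 = 0.975; the value approaches 0.9 as the episode ends.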

    def __del__(self):
        if self.producer_idx == 0:
            self.wandb_run.finish()
@@ -642,65 +850,18 @@ class AsyncProducer(BaseProducer):


@ray.remote
class AsyncServer:
class AsyncSimpleProducer(BaseAsyncProducer):
    """
    A async worker for inference only
    Asynchronous version of the producer that uses vLLM for generation.
    This class is designed to handle multiple producer actors and distribute tasks among them.
    """

    def __init__(
        self,
        producer_idx,
        num_producers,
        model_config,
        generate_config,
        tokenizer_config=None,
        microbatch_size=1,
        backend="transformers",
        num_generations: int = 8,
        eval_generation_config={},
    ):
        tokenizer_path = model_config["path"]
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
        self.tokenizer.padding_side = "left"
        self.microbatch_size = microbatch_size
        self.producer_idx = producer_idx
        self.num_producers = num_producers
        assert backend == "async-vllm", f"AsyncProducer only supports async-vllm backend, got {backend}"
        self.model = self.backend_cls(
            model_config, generate_config, self.tokenizer, num_generations, self.microbatch_size
        )
        self.eval_generation_config = copy.deepcopy(self.model.generate_config)
        self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
        self.eval_generation_config.update(eval_generation_config)
        self.eval_sample_params = SamplingParams(**self.eval_generation_config)

    @torch.no_grad()
    async def rollout(self, input_ids, attention_mask, **kwargs):
        tasks = []
        for prompt_id in range(input_ids.size(0)):
            new_kwargs = copy.deepcopy(kwargs)
            if "gt_answer" in new_kwargs:
                new_kwargs["gt_answer"] = new_kwargs["gt_answer"][prompt_id]
            if "test_cases" in new_kwargs:
                new_kwargs["test_cases"] = new_kwargs["test_cases"][prompt_id]
            tasks.append(
                self.model.generate(
                    input_ids[prompt_id].unsqueeze(0),
                    attention_mask[prompt_id].unsqueeze(0),
                    **new_kwargs,
                )
            )
        # print(f"Producer {self.producer_idx} running {len(tasks)} tasks")
        rollouts = await asyncio.gather(*tasks)
        rollouts = {
            k: (
                torch.cat([r[k] for r in rollouts], dim=0)
                if k not in ["gt_answer", "test_cases"]
                else [r[k] for r in rollouts]
            )
            for k in rollouts[0].keys()
        }
        if self.producer_idx == 0 and not self.eval_mode:
        # naive rollout strategy without load balancing
        rollouts = await self.generate(input_ids, attention_mask, **kwargs)
        if hasattr(self, "rollout_log_file") and self.producer_idx == 0 and not self.eval_mode:
            # for the agentic producer, AsyncSimpleProducer is not the main producer, so we don't log rollouts
            if (
                self.consumer_global_step - self.latest_rollout_log_step >= self.log_rollout_interval
                or self.latest_rollout_log_step == -1
@@ -721,11 +882,6 @@ class AsyncServer:
                self.latest_rollout_log_step = self.consumer_global_step
        return rollouts

    def __del__(self):
        if self.producer_idx == 0:
            self.wandb_run.finish()
        if hasattr(self, "rollout_log_file"):
            self.rollout_log_file.close()

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)
    async def generate(self, input_ids, attention_mask, **kwargs):
        rollouts = await super().generate(input_ids, attention_mask, **kwargs)
        return rollouts
@@ -110,7 +110,11 @@ if __name__ == "__main__":

    # Sampling parameters
    parser.add_argument(
        "-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm", "async-vllm"]
        "-b",
        "--backend",
        type=str,
        default="transformers",
        choices=["transformers", "vllm", "async-vllm", "async-agentic-vllm"],
    )
    parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
    parser.add_argument(
@@ -215,7 +219,7 @@ if __name__ == "__main__":
            namespace="ray-example",
            runtime_env={
                "env_vars": {
                    "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                    # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                    "TOKENIZERS_PARALLELISM": "false",
                },
            },
@@ -228,7 +232,7 @@ if __name__ == "__main__":
            _temp_dir=args.ray_dir,
            runtime_env={
                "env_vars": {
                    "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                    # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                    "TOKENIZERS_PARALLELISM": "false",
                },
            },
@@ -237,7 +241,7 @@ if __name__ == "__main__":
    if args.top_k is None:
        if args.backend == "transformers":
            args.top_k = 50
        elif args.backend == "vllm" or args.backend == "async-vllm":
        elif "vllm" in args.backend:
            args.top_k = -1

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disable tokenizers parallelism to avoid deadlock
@@ -265,7 +269,7 @@ if __name__ == "__main__":
            )
        )
        eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
    elif args.backend == "vllm" or args.backend == "async-vllm":
    elif args.backend == "vllm" or args.backend == "async-vllm" or args.backend == "async-agentic-vllm":
        # os.environ["VLLM_DP_SIZE"] = str(args.producer_data_parallel_size)
        inference_model_config.update(
            dict(
@@ -358,6 +362,25 @@ if __name__ == "__main__":
        # Default system prompt
        args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]

    if "agentic" in args.backend:
        assert "vllm" in args.backend, "Agentic backend only supports async-agentic-vllm backends."
        generate_config["n"] = 1  # the agentic producer uses AsyncProducer, which processes one request at a time
        generate_config["max_tokens"] = (
            2048  # max new tokens for each agentic step, usually smaller than max_new_tokens as the agentic model generates multiple steps
        )
        agentic_config = {
            "model": args.model,
            "model_type": "transformers",
            "generate_cfg": {
                "max_input_tokens": args.max_new_tokens + args.max_prompt_tokens,
            },
        }
        agentic_config["generate_cfg"].update(
            {k: v for k, v in generate_config.items() if k in ["top_k", "top_p", "temperature"]}
        )
    else:
        agentic_config = None
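
    # Worked example (hypothetical values): with args.max_new_tokens = 4096 and
    # args.max_prompt_tokens = 2048, "max_input_tokens" becomes 6144, and only the sampling keys
    # among top_k / top_p / temperature that are present in generate_config are copied into
    # "generate_cfg".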

    launch_distributed(
        num_producers=args.num_inferencer,
        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size)
@@ -378,6 +401,7 @@ if __name__ == "__main__":
        num_generations=args.num_generations,
        train_model_config=train_model_config,
        grpo_config=grpo_config,
        agentic_config=agentic_config,
        plugin_config={
            "tp_size": args.tensor_parallel_size,
            "pp_size": args.pipeline_parallel_size,