Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-16 07:03:40 +00:00)
[feature] fit RL style generation (#6213)

* [feature] fit rl style generation
* [doc] add docstr
* [doc] add docstr

parent 43c9b5fb44
commit de282dd694
@@ -168,6 +168,11 @@ class SimpleConsumer(BaseConsumer):
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        labels = kwargs["input_ids"].clone()
+        labels[kwargs["attention_mask"] == 0] = -100
+        kwargs["labels"] = labels
+        assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape
+
         need_update = (step_idx + 1) % self.num_microbatches == 0

         ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer)
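Review note: the new `step` body builds causal-LM labels directly from the rollout batch, masking padded positions with the Hugging Face ignore index -100. A minimal standalone sketch of that masking (tensor values are made up for illustration, not taken from the commit):

    import torch

    input_ids = torch.tensor([[5, 6, 7, 8, 9], [5, 6, 7, 0, 0]])       # [B, S]
    attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])  # [B, S]

    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # -100 is ignored by torch.nn.CrossEntropyLoss
    # labels -> [[5, 6, 7, 8, 9], [5, 6, 7, -100, -100]]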
@@ -2,10 +2,12 @@ from typing import Any, Dict

 import torch
 import torch.nn.functional as F
-from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer

 from colossalai.utils import get_current_device

+from .utils import log_probs_from_logits, update_by_default
+
 try:
     import sglang as sgl
 except ImportError:
@@ -22,37 +24,73 @@ class BaseInferenceBackend:
         pass

     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
-        pass
+        """Generate new tokens given input_ids and attention_mask.
+
+        Args:
+            input_ids (torch.Tensor): shape [B, S]
+            attention_mask (torch.Tensor): shape [B, S]
+
+        Returns:
+            Dict[str, torch.Tensor]: containing the
+                - input_ids (torch.Tensor): shape [B, S+N]
+                - attention_mask (torch.Tensor): shape [B, S+N]
+                - action_log_probs (torch.Tensor): shape [B, N]
+                - action_mask (torch.Tensor): shape [B, N]
+                where N is the number of generated tokens. And all tensors should be on CUDA.
+        """

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
         pass


 class TransformersInferenceBackend(BaseInferenceBackend):
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
-        path = model_config.pop("path")
-        defaut_config = dict(
-            trust_remote_code=True,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        )
-        defaut_config.update(model_config)
-        self.model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(path, **defaut_config)
-        self.generate_config = generate_config
+    DEFAULT_MODEL_CONFIG = dict(
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+    )
+    FORCE_MODEL_CONFIG = dict(
+        device_map="auto",
+    )
+    FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
+
+    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+        model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
+        model_config.update(self.FORCE_MODEL_CONFIG)
+        path = model_config.pop("path")
+        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.generate_config = generate_config.copy()
+        self.generate_config.update(self.FORCE_GENERATE_CONFIG)
+        self.tokenizer = tokenizer

+    @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
         out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
         input_len = input_ids.shape[-1]
-        labels = out.clone()
-        labels[..., :input_len] = -100
-        attention_mask = F.pad(attention_mask, (0, out.shape[-1] - input_len), value=1)
-        attention_mask = attention_mask.expand_as(labels)
+        new_token_ids = out.sequences[:, input_len:]
+        # get log probs
+        assert new_token_ids.shape[-1] == len(out.logits)
+        action_log_probs = []
+        for i, logits in enumerate(out.logits):
+            action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
+        action_log_probs = torch.cat(action_log_probs, dim=1)
+        # get action mask
+        action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
+        if self.tokenizer.eos_token_id is not None:
+            for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
+                action_mask[indices[0], indices[1] + 1 :] = 0
+
+        if attention_mask.size(0) != action_mask.size(0):
+            assert action_mask.size(0) % attention_mask.size(0) == 0
+            attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
+
+        attention_mask = torch.cat((attention_mask, action_mask), dim=1)
         data = {
-            "input_ids": out,
+            "input_ids": out.sequences,
             "attention_mask": attention_mask,
-            "labels": labels,
+            "action_log_probs": action_log_probs,
+            "action_mask": action_mask,
         }
         return data

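Review note: with `output_logits=True` and `return_dict_in_generate=True`, `model.generate` returns one `[B, vocab_size]` logits tensor per generated step in `out.logits`, which is why the backend gathers log probs step by step. A rough sketch of the same gathering outside the class, assuming `step_logits` is that list and `new_token_ids` is the `[B, N]` tensor of sampled tokens (names are illustrative, not from the commit):

    import torch
    import torch.nn.functional as F

    def gather_step_log_probs(step_logits, new_token_ids):
        per_step = []
        for i, logits in enumerate(step_logits):
            log_probs = F.log_softmax(logits, dim=-1)               # [B, vocab_size]
            token = new_token_ids[:, i : i + 1]                     # [B, 1]
            per_step.append(log_probs.gather(dim=-1, index=token))  # [B, 1]
        return torch.cat(per_step, dim=1)                           # [B, N]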
@@ -75,6 +113,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
         self.tokenizer = tokenizer
         self.config = AutoConfig.from_pretrained(path)

+    @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config)
         out_tokens = []
@@ -110,45 +149,66 @@ class SGLangInferenceBackend(BaseInferenceBackend):


 class VLLMInferenceBackend(BaseInferenceBackend):
+    DEFAULT_MODEL_CONFIG = dict(
+        trust_remote_code=True,
+    )
+    FORCE_GENERATE_CONFIG = dict(
+        logprobs=0,
+    )
+
     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
         if LLM is None:
             raise ImportError("vllm is not installed")
+        model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        defaut_config = dict(
-            trust_remote_code=True,
-            # skip_tokenizer_init=True,
-        )
-        defaut_config.update(model_config)
-        self.llm = LLM(path, **defaut_config)
-        self.generate_config = SamplingParams(**generate_config, stop_token_ids=[tokenizer.eos_token_id])
+        self.llm = LLM(path, **model_config)
+        generate_config = generate_config.copy()
+        generate_config.update(self.FORCE_GENERATE_CONFIG)
+        self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.config = AutoConfig.from_pretrained(path)

+    @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         outputs = self.llm.generate(
             prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
+        log_probs = []
         for out in outputs:
-            out_tokens.append(list(out.outputs[0].token_ids))
-            out_len.append(len(out.outputs[0].token_ids))
+            for output_i in out.outputs:
+                out_len.append(len(output_i.token_ids))
+                out_tokens.append(list(output_i.token_ids))
+                assert len(output_i.logprobs) == len(output_i.token_ids)
+                p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
+                log_probs.append(p)
+
+        # pad them
         max_len = max(out_len)
-        input_len = input_ids.shape[-1]
-        attention_mask = F.pad(attention_mask, (0, max_len), value=1)
-        for i in range(len(out_tokens)):
-            out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i])
-            attention_mask[i, input_len + out_len[i] :] = 0
-        out = torch.tensor(out_tokens)
-        out = torch.cat((input_ids, out), dim=1)
-        labels = out.clone()
-        labels[..., :input_len] = -100
-        for i in range(len(out_len)):
-            labels[i, input_len + out_len[i] :] = -100
+        action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
+
+        for i, new_token_ids in enumerate(out_tokens):
+            pad_len = max_len - out_len[i]
+            out_tokens[i] = new_token_ids + [self.tokenizer.pad_token_id] * pad_len
+            log_probs[i] = log_probs[i] + [0.0] * pad_len
+            action_mask[i, out_len[i] :] = 0
+
+        out_tokens = torch.tensor(out_tokens)
+        log_probs = torch.tensor(log_probs)
+        if attention_mask.size(0) != action_mask.size(0):
+            assert action_mask.size(0) % attention_mask.size(0) == 0
+            num_returns = action_mask.size(0) // attention_mask.size(0)
+            attention_mask = attention_mask.repeat_interleave(num_returns, dim=0)
+            input_ids = input_ids.repeat_interleave(num_returns, dim=0)
+
+        out_tokens = torch.cat((input_ids, out_tokens), dim=1)
+        attention_mask = torch.cat((attention_mask, action_mask), dim=1)
+
         data = {
-            "input_ids": out,
+            "input_ids": out_tokens,
             "attention_mask": attention_mask,
-            "labels": labels,
+            "action_log_probs": log_probs,
+            "action_mask": action_mask,
         }
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
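Review note: `SamplingParams(logprobs=0)` makes vLLM return the log probability of each sampled token, so `output_i.logprobs` is a per-token list of `{token_id: Logprob}` mappings and `m[t].logprob` pulls out the sampled token's value. The padding to a rectangular batch can be sketched with plain lists (values are made up):

    out_tokens = [[11, 12, 13], [21, 22]]           # two samples of different lengths
    log_probs = [[-0.1, -0.2, -0.3], [-0.5, -0.6]]
    pad_token_id = 0

    max_len = max(len(t) for t in out_tokens)
    for i, toks in enumerate(out_tokens):
        pad_len = max_len - len(toks)
        out_tokens[i] = toks + [pad_token_id] * pad_len
        log_probs[i] = log_probs[i] + [0.0] * pad_len
    # padded tails are zeroed out by action_mask downstream, so the filler 0.0 log probs are inert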
@@ -159,6 +219,6 @@ class VLLMInferenceBackend(BaseInferenceBackend):

 BACKEND_MAP = {
     "transformers": TransformersInferenceBackend,
-    "sglang": SGLangInferenceBackend,
+    # "sglang": SGLangInferenceBackend, # sglang backend will stuck the process due to unknown reason
     "vllm": VLLMInferenceBackend,
 }
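Review note: the call site that consumes BACKEND_MAP is not part of this diff; presumably a producer selects a backend class by name, roughly like the hypothetical helper below:

    def make_backend(name, model_config, generate_config, tokenizer):
        # hypothetical lookup helper; not part of this commit
        if name not in BACKEND_MAP:
            raise ValueError(f"Unknown inference backend: {name}")
        return BACKEND_MAP[name](model_config, generate_config, tokenizer)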
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Any, Dict, List

 import torch

@@ -25,16 +25,42 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:


 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-    # compress attention_mask to save bandwidth
+    # compress mask to save bandwidth
     if "attention_mask" in batch:
-        attention_mask = batch["attention_mask"]
-        batch["attention_mask"] = attention_mask.to(torch.bool)
+        batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
+    if "action_mask" in batch:
+        batch["action_mask"] = batch["action_mask"].to(torch.bool)
     return batch


 def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-    # decompress attention_mask
+    # decompress mask
     if "attention_mask" in batch:
-        attention_mask = batch["attention_mask"]
-        batch["attention_mask"] = attention_mask.to(torch.int)
+        batch["attention_mask"] = batch["attention_mask"].to(torch.int)
+    if "action_mask" in batch:
+        batch["action_mask"] = batch["action_mask"].to(torch.int)
     return batch
+
+
+def update_by_default(data: Dict[str, Any], default: Dict[str, Any]) -> Dict[str, Any]:
+    data = data.copy()
+    for k, v in default.items():
+        if k not in data:
+            data[k] = v
+    return data
+
+
+def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the log probabilities from logits for the given labels.
+
+    Args:
+        logits (torch.Tensor): The input logits.
+        labels (torch.Tensor): The target labels.
+
+    Returns:
+        torch.Tensor: The log probabilities corresponding to the labels.
+    """
+    log_probs = torch.log_softmax(logits, dim=-1)
+    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+    return per_label_logps.squeeze(-1)
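Review note: both new helpers are small pure functions; a usage sketch with made-up shapes and values (not from the commit):

    import torch

    cfg = update_by_default({"path": "my/model"}, {"trust_remote_code": True})
    # -> {"path": "my/model", "trust_remote_code": True}; existing keys are never overwritten

    logits = torch.randn(2, 4, 32000)                         # [B, N, vocab_size]
    labels = torch.randint(0, 32000, (2, 4))                  # [B, N]
    token_log_probs = log_probs_from_logits(logits, labels)   # [B, N], log-prob of each label token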