From de282dd694ebd8c226017ff2bd68cff56a86e820 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 21 Feb 2025 17:28:19 +0800 Subject: [PATCH] [feature] fit RL style generation (#6213) * [feature] fit rl style generation * [doc] add docstr * [doc] add docstr --- .../coati/distributed/consumer.py | 5 + .../coati/distributed/inference_backend.py | 144 +++++++++++++----- .../ColossalChat/coati/distributed/utils.py | 40 ++++- 3 files changed, 140 insertions(+), 49 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 61417f7e6..84a69979f 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -168,6 +168,11 @@ class SimpleConsumer(BaseConsumer): self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer) def step(self, step_idx: int, **kwargs) -> Optional[float]: + labels = kwargs["input_ids"].clone() + labels[kwargs["attention_mask"] == 0] = -100 + kwargs["labels"] = labels + assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape + need_update = (step_idx + 1) % self.num_microbatches == 0 ctx = nullcontext() if need_update else self.booster.no_sync(self.model, self.optimizer) diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index d40808ab4..95b7d1e80 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -2,10 +2,12 @@ from typing import Any, Dict import torch import torch.nn.functional as F -from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer from colossalai.utils import get_current_device +from .utils import log_probs_from_logits, update_by_default + try: import sglang as sgl except ImportError: @@ -22,37 +24,73 @@ class BaseInferenceBackend: pass def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: - pass + """Generate new tokens given input_ids and attention_mask. + + Args: + input_ids (torch.Tensor): shape [B, S] + attention_mask (torch.Tensor): shape [B, S] + + Returns: + Dict[str, torch.Tensor]: containing the + - input_ids (torch.Tensor): shape [B, S+N] + - attention_mask (torch.Tensor): shape [B, S+N] + - action_log_probs (torch.Tensor): shape [B, N] + - action_mask (torch.Tensor): shape [B, N] + where N is the number of generated tokens. And all tensors should be on CUDA. 
+ """ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: pass class TransformersInferenceBackend(BaseInferenceBackend): - def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): - path = model_config.pop("path") - defaut_config = dict( - trust_remote_code=True, - torch_dtype=torch.bfloat16, - device_map="auto", - ) - defaut_config.update(model_config) - self.model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(path, **defaut_config) - self.generate_config = generate_config + DEFAULT_MODEL_CONFIG = dict( + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + FORCE_MODEL_CONFIG = dict( + device_map="auto", + ) + FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True) + def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): + model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) + model_config.update(self.FORCE_MODEL_CONFIG) + path = model_config.pop("path") + self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.generate_config = generate_config.copy() + self.generate_config.update(self.FORCE_GENERATE_CONFIG) + self.tokenizer = tokenizer + + @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: input_ids = input_ids.to(get_current_device()) attention_mask = attention_mask.to(get_current_device()) out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config) input_len = input_ids.shape[-1] - labels = out.clone() - labels[..., :input_len] = -100 - attention_mask = F.pad(attention_mask, (0, out.shape[-1] - input_len), value=1) - attention_mask = attention_mask.expand_as(labels) + new_token_ids = out.sequences[:, input_len:] + # get log probs + assert new_token_ids.shape[-1] == len(out.logits) + action_log_probs = [] + for i, logits in enumerate(out.logits): + action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1])) + action_log_probs = torch.cat(action_log_probs, dim=1) + # get action mask + action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype) + if self.tokenizer.eos_token_id is not None: + for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id): + action_mask[indices[0], indices[1] + 1 :] = 0 + + if attention_mask.size(0) != action_mask.size(0): + assert action_mask.size(0) % attention_mask.size(0) == 0 + attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0) + + attention_mask = torch.cat((attention_mask, action_mask), dim=1) data = { - "input_ids": out, + "input_ids": out.sequences, "attention_mask": attention_mask, - "labels": labels, + "action_log_probs": action_log_probs, + "action_mask": action_mask, } return data @@ -75,6 +113,7 @@ class SGLangInferenceBackend(BaseInferenceBackend): self.tokenizer = tokenizer self.config = AutoConfig.from_pretrained(path) + @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config) out_tokens = [] @@ -110,45 +149,66 @@ class SGLangInferenceBackend(BaseInferenceBackend): class VLLMInferenceBackend(BaseInferenceBackend): + DEFAULT_MODEL_CONFIG = dict( + trust_remote_code=True, + ) + FORCE_GENERATE_CONFIG = dict( + 
logprobs=0, + ) + def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): if LLM is None: raise ImportError("vllm is not installed") + model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) path = model_config.pop("path") - defaut_config = dict( - trust_remote_code=True, - # skip_tokenizer_init=True, - ) - defaut_config.update(model_config) - self.llm = LLM(path, **defaut_config) - self.generate_config = SamplingParams(**generate_config, stop_token_ids=[tokenizer.eos_token_id]) + self.llm = LLM(path, **model_config) + generate_config = generate_config.copy() + generate_config.update(self.FORCE_GENERATE_CONFIG) + self.generate_config = SamplingParams(**generate_config) self.tokenizer = tokenizer - self.config = AutoConfig.from_pretrained(path) + @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: outputs = self.llm.generate( prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False ) out_tokens = [] out_len = [] + log_probs = [] for out in outputs: - out_tokens.append(list(out.outputs[0].token_ids)) - out_len.append(len(out.outputs[0].token_ids)) + for output_i in out.outputs: + out_len.append(len(output_i.token_ids)) + out_tokens.append(list(output_i.token_ids)) + assert len(output_i.logprobs) == len(output_i.token_ids) + p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)] + log_probs.append(p) + + # pad them max_len = max(out_len) - input_len = input_ids.shape[-1] - attention_mask = F.pad(attention_mask, (0, max_len), value=1) - for i in range(len(out_tokens)): - out_tokens[i] = out_tokens[i] + [self.tokenizer.pad_token_id] * (max_len - out_len[i]) - attention_mask[i, input_len + out_len[i] :] = 0 - out = torch.tensor(out_tokens) - out = torch.cat((input_ids, out), dim=1) - labels = out.clone() - labels[..., :input_len] = -100 - for i in range(len(out_len)): - labels[i, input_len + out_len[i] :] = -100 + action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype) + + for i, new_token_ids in enumerate(out_tokens): + pad_len = max_len - out_len[i] + out_tokens[i] = new_token_ids + [self.tokenizer.pad_token_id] * pad_len + log_probs[i] = log_probs[i] + [0.0] * pad_len + action_mask[i, out_len[i] :] = 0 + + out_tokens = torch.tensor(out_tokens) + log_probs = torch.tensor(log_probs) + if attention_mask.size(0) != action_mask.size(0): + assert action_mask.size(0) % attention_mask.size(0) == 0 + num_returns = action_mask.size(0) // attention_mask.size(0) + attention_mask = attention_mask.repeat_interleave(num_returns, dim=0) + input_ids = input_ids.repeat_interleave(num_returns, dim=0) + + out_tokens = torch.cat((input_ids, out_tokens), dim=1) + attention_mask = torch.cat((attention_mask, action_mask), dim=1) + data = { - "input_ids": out, + "input_ids": out_tokens, "attention_mask": attention_mask, - "labels": labels, + "action_log_probs": log_probs, + "action_mask": action_mask, } data = {k: v.to(get_current_device()) for k, v in data.items()} return data @@ -159,6 +219,6 @@ class VLLMInferenceBackend(BaseInferenceBackend): BACKEND_MAP = { "transformers": TransformersInferenceBackend, - "sglang": SGLangInferenceBackend, + # "sglang": SGLangInferenceBackend, # sglang backend will stuck the process due to unknown reason "vllm": VLLMInferenceBackend, } diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py 
index 2f3267a1f..533a5ffb2 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Any, Dict, List import torch @@ -25,16 +25,42 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - # compress attention_mask to save bandwidth + # compress mask to save bandwidth if "attention_mask" in batch: - attention_mask = batch["attention_mask"] - batch["attention_mask"] = attention_mask.to(torch.bool) + batch["attention_mask"] = batch["attention_mask"].to(torch.bool) + if "action_mask" in batch: + batch["action_mask"] = batch["action_mask"].to(torch.bool) return batch def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - # decompress attention_mask + # decompress mask if "attention_mask" in batch: - attention_mask = batch["attention_mask"] - batch["attention_mask"] = attention_mask.to(torch.int) + batch["attention_mask"] = batch["attention_mask"].to(torch.int) + if "action_mask" in batch: + batch["action_mask"] = batch["action_mask"].to(torch.int) return batch + + +def update_by_default(data: Dict[str, Any], default: Dict[str, Any]) -> Dict[str, Any]: + data = data.copy() + for k, v in default.items(): + if k not in data: + data[k] = v + return data + + +def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + Compute the log probabilities from logits for the given labels. + + Args: + logits (torch.Tensor): The input logits. + labels (torch.Tensor): The target labels. + + Returns: + torch.Tensor: The log probabilities corresponding to the labels. + """ + log_probs = torch.log_softmax(logits, dim=-1) + per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) + return per_label_logps.squeeze(-1)
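
For readers of this patch: the core of the RL-style change is that each inference backend now returns per-token "action_log_probs" and an "action_mask" alongside the padded sequences, instead of the old "labels" tensor. The sketch below is a minimal, standalone illustration of those two pieces, not part of the patch; it uses toy shapes and a hypothetical eos_token_id, reuses the log_probs_from_logits helper added in coati/distributed/utils.py, and mirrors the EOS loop in TransformersInferenceBackend.generate.

    import torch


    def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Same helper as added in coati/distributed/utils.py: gather log p(label) from logits.
        log_probs = torch.log_softmax(logits, dim=-1)
        return log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


    if __name__ == "__main__":
        B, N, V = 2, 4, 8          # batch size, generated length, vocab size (toy values)
        eos_token_id = 3           # hypothetical EOS id, for illustration only
        logits = torch.randn(B, N, V)              # one logit row per generated position
        new_token_ids = torch.randint(0, V, (B, N))
        new_token_ids[0, 1] = eos_token_id         # force an early EOS in sample 0

        # Per-token log-probabilities of the sampled tokens, shape [B, N]. Applying the
        # helper to the stacked [B, N, V] logits is equivalent to the per-step calls in
        # the transformers backend, since log_softmax acts independently per position.
        action_log_probs = log_probs_from_logits(logits, new_token_ids)

        # Action mask: 1 for real actions, 0 for every position after the first EOS,
        # mirroring the torch.nonzero loop in TransformersInferenceBackend.generate.
        action_mask = torch.ones_like(new_token_ids)
        for indices in torch.nonzero(new_token_ids == eos_token_id):
            action_mask[indices[0], indices[1] + 1 :] = 0

        print(action_log_probs.shape)  # torch.Size([2, 4])
        print(action_mask[0])          # e.g. tensor([1, 1, 0, 0])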
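
A second standalone sketch, also outside the patch: when the sampler returns several completions per prompt (as the vLLM backend can), both backends expand the prompt-side tensors with repeat_interleave before concatenating the action mask onto the attention mask. The toy sizes below are assumptions chosen only to show the resulting shapes.

    import torch

    # Toy setup: 2 prompts of length 3, 2 sampled completions per prompt, each of length 4.
    B, S, N, num_returns = 2, 3, 4, 2
    input_ids = torch.randint(10, 100, (B, S))
    attention_mask = torch.ones(B, S, dtype=torch.int)
    action_mask = torch.ones(B * num_returns, N, dtype=torch.int)  # one row per completion

    # Mirror of the expansion logic shared by the transformers and vLLM backends:
    # if there are more completions than prompts, repeat each prompt row accordingly.
    if attention_mask.size(0) != action_mask.size(0):
        assert action_mask.size(0) % attention_mask.size(0) == 0
        n = action_mask.size(0) // attention_mask.size(0)
        attention_mask = attention_mask.repeat_interleave(n, dim=0)
        input_ids = input_ids.repeat_interleave(n, dim=0)

    # Final attention mask covers prompt + generated tokens: [B * num_returns, S + N].
    attention_mask = torch.cat((attention_mask, action_mask), dim=1)
    print(attention_mask.shape)  # torch.Size([4, 7])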