Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)

Commit: polish
@@ -154,6 +154,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
+        n=4,
     )
 
     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
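FORCE_GENERATE_CONFIG is merged over the user-supplied generate_config before it is turned into vLLM SamplingParams, so every prompt is sampled 4 times (n=4) and the sampled token's log-probability is returned (logprobs=0). A minimal sketch of that merge; the user settings shown here are made up:

from vllm import SamplingParams

FORCE_GENERATE_CONFIG = dict(
    logprobs=0,  # return the log-probability of each sampled token
    n=4,         # sample 4 completions per prompt
)

user_config = {"temperature": 0.9, "max_tokens": 512, "n": 1}  # made-up user settings
user_config.update(FORCE_GENERATE_CONFIG)                      # forced keys override user keys
sampling_params = SamplingParams(**user_config)                # n=4 and logprobs=0 win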
@@ -166,19 +167,24 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
+        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        micro_batch_size = input_ids.size(0)
+        response_start_idx = input_ids.size(1)
         outputs = self.llm.generate(
             prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
         log_probs = []
+        response_idx = []
         for out in outputs:
             for output_i in out.outputs:
                 out_len.append(len(output_i.token_ids))
                 out_tokens.append(list(output_i.token_ids))
+                response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
                 assert len(output_i.logprobs) == len(output_i.token_ids)
                 p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
                 log_probs.append(p)
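With n=4, self.llm.generate returns one RequestOutput per prompt, each holding 4 CompletionOutputs, so the loop flattens the batch into micro_batch_size * 4 rows. Each CompletionOutput's logprobs is a list of dicts (one per generated token) mapping candidate token id to a Logprob object, which is why the sampled token's log-probability is looked up as m[t].logprob. A toy illustration of that lookup, using plain floats in place of vLLM's Logprob objects:

logprobs_per_step = [          # stand-in for output_i.logprobs
    {11: -0.10, 99: -2.30},    # step 0: candidate token id -> log-probability
    {42: -0.50, 7: -1.20},     # step 1
]
token_ids = [11, 42]           # stand-in for output_i.token_ids (the sampled tokens)

p = [step[t] for step, t in zip(logprobs_per_step, token_ids)]
assert p == [-0.10, -0.50]     # log-prob of each token that was actually emitted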
@@ -195,6 +201,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
 
         out_tokens = torch.tensor(out_tokens)
         log_probs = torch.tensor(log_probs)
+        response_idx = torch.tensor(response_idx)
 
         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
             num_returns = action_mask.size(0) // attention_mask.size(0)
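The size check covers the case where each of the B prompts produced several responses: action_mask then has num_returns times as many rows as the prompt-side attention_mask. The diff cuts off before the branch body, but presumably the prompt-side tensors are repeated along dim 0 to line up row-for-row, roughly as in this hedged sketch (shapes are made up):

import torch

attention_mask = torch.ones(2, 5, dtype=torch.long)   # B=2 prompts, prompt length 5
action_mask = torch.ones(8, 7, dtype=torch.long)      # 2 prompts * 4 responses, response length 7

assert action_mask.size(0) % attention_mask.size(0) == 0
num_returns = action_mask.size(0) // attention_mask.size(0)            # -> 4
attention_mask = attention_mask.repeat_interleave(num_returns, dim=0)  # (8, 5), row-aligned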
@@ -209,9 +217,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             "attention_mask": attention_mask,
             "action_log_probs": log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }
 
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
         if "gt_answer" in kwargs:
             data["gt_answer"] = kwargs["gt_answer"]
+            # repeat gt_answer for each prompt.
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
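Finally, the flat (micro_batch_size * num_generations, seq_len) tensors are regrouped per prompt, and the single gt_answer per prompt is repeated so each of the 4 generations can be scored against it. A toy example of both steps with made-up shapes (a (2, 1, 5) gt_answer is assumed so that dim=1 is the per-prompt answer axis):

import torch

micro_batch_size, num_generations, seq_len = 2, 4, 3
flat = torch.arange(micro_batch_size * num_generations * seq_len).view(-1, seq_len)  # (8, 3)

# Row i * num_generations + j of the flat batch becomes grouped[i, j].
grouped = flat.view(micro_batch_size, num_generations, flat.size(-1))                # (2, 4, 3)
assert torch.equal(grouped[1, 2], flat[1 * num_generations + 2])

gt_answer = torch.zeros(micro_batch_size, 1, 5)                  # one answer per prompt
gt_answer = gt_answer.repeat_interleave(num_generations, dim=1)  # (2, 4, 5)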