mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
spot a possible bug
This commit is contained in:
@@ -80,8 +80,9 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
if self.num_generations > 1:
|
||||
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
|
||||
generate_config = kwargs.get("generate_config", self.generate_config)
|
||||
out = self.model.generate(
|
||||
input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
|
||||
input_ids, attention_mask=attention_mask, **kwargs, **generate_config, tokenizer=self.tokenizer
|
||||
)
|
||||
input_len = input_ids.shape[-1]
|
||||
new_token_ids = out.sequences[:, input_len:]
|
||||
|
Reference in New Issue
Block a user