Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] User Experience: update the logic of default tokenizer and generation config. (#5337)
* add
* fix
* fix
* pause
* fix
* fix pytest
* align
* fix
* license
* fix
* fix
* fix readme
* fix some bugs
* remove tokenizer config
@@ -33,7 +33,7 @@ class InferenceEngine:
     Args:
         model (nn.Module): Path or nn.Module of this model.
-        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
         inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
         model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
@@ -42,19 +42,20 @@ class InferenceEngine:
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        inference_config: Optional["InferenceConfig"] = None,
+        tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
+        inference_config: InferenceConfig,
         verbose: bool = False,
         model_policy: Policy = None,
     ) -> None:
         assert inference_config, "Please provide inference_config."
-        self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
+        assert tokenizer, "Please provide a tokenizer, either a defined one or str"
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
         self.dtype = inference_config.dtype

+        self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.generation_config = inference_config.to_generation_config(self.model_config)
         model = model.eval()
         model.to(self.dtype)
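For orientation, a minimal usage sketch of the new constructor contract follows. The model name, the InferenceConfig defaults, and the import paths are assumptions for illustration only; they are not taken from this diff.

# Sketch only (not from the commit): constructing the engine after this change.
# Assumes colossalai.inference exposes InferenceConfig/InferenceEngine at these
# paths and that InferenceConfig() has usable defaults.
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

model_path = "meta-llama/Llama-2-7b-hf"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# tokenizer and inference_config are now both required (asserted in __init__);
# the engine builds its default generation config internally via
# inference_config.to_generation_config(model.config).
inference_config = InferenceConfig()
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)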
@@ -80,6 +81,8 @@ class InferenceEngine:

         self.request_handler = RequestHandler(self.inference_config, self.model_config)
         self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
+        # DISCUSS maybe move this into batch info?
+
         self.counter = count()

     def _verify_config(self) -> None:
@@ -137,7 +140,7 @@ class InferenceEngine:
         self,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        generation_config: GenerationConfig = None,
+        generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
         """
         Executing the inference step.
@@ -158,6 +161,10 @@ class InferenceEngine:
         output_seqs_list = []
         output_tokens_list = []

+        # intuition: If user provide a generation config, we should replace the existing one.
+        if generation_config is not None:
+            self.generation_config = generation_config
+
         while self.request_handler.check_unfinished_seqs():
             output_seqs_list += self.step()
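The new branch above means a GenerationConfig passed to generate() replaces the engine's stored default before decoding begins. Continuing the constructor sketch above, a hedged illustration; the prompt text and sampling values are placeholders.

# Sketch only: illustrating the override added in generate().
from transformers import GenerationConfig

# No explicit config: the engine uses the GenerationConfig it derived from
# inference_config in __init__.
outputs = engine.generate(prompts=["Introduce some landmarks in Beijing"])

# Explicit config: replaces self.generation_config for this and later calls.
gen_cfg = GenerationConfig(max_new_tokens=128, do_sample=True, top_p=0.9)
outputs = engine.generate(
    prompts=["Introduce some landmarks in Beijing"],
    generation_config=gen_cfg,
)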
@@ -285,8 +292,8 @@ class InferenceEngine:

         if self.inference_config.pad_input:
             logits = logits[:, -1, :]

         self.request_handler.search_tokens(self.generation_config, logits)

         finished_sequences = self.request_handler.update()

         return finished_sequences