[Inference] User Experience: update the logic of default tokenizer and generation config. (#5337)

* add

* fix

* fix

* pause

* fix

* fix pytest

* align

* fix

* license

* fix

* fix

* fix readme

* fix some bugs

* remove tokenizer config
Author: Jianghai
Date: 2024-02-07 17:55:48 +08:00
Committed by: GitHub
Parent: 6fb4bcbb24
Commit: 1f8c7e7046
7 changed files with 62 additions and 23 deletions
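
The net effect of this commit: InferenceEngine no longer silently falls back to a default tokenizer or inference config; both arguments are required and asserted in __init__, and a default GenerationConfig is now derived from the InferenceConfig. A minimal usage sketch of the updated constructor is below; the checkpoint name, the engine import path, and the InferenceConfig defaults are illustrative assumptions, not part of this diff.

# Illustrative sketch of the updated API (checkpoint name and engine import
# path are assumptions; InferenceConfig is imported as shown in the diff below).
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # assumed module path

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Both the tokenizer and the inference config are now required and asserted.
inference_config = InferenceConfig()  # field values omitted; defaults assumed sufficient
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# With no generation_config argument, generate() uses the GenerationConfig
# that __init__ built from inference_config.
outputs = engine.generate(prompts=["Introduce some landmarks in Beijing."])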

@@ -33,7 +33,7 @@ class InferenceEngine:
     Args:
         model (nn.Module): Path or nn.Module of this model.
-        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
         inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
         model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
@@ -42,19 +42,20 @@ class InferenceEngine:
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-        inference_config: Optional["InferenceConfig"] = None,
+        tokenizer: [Union[PreTrainedTokenizer, PreTrainedTokenizerFast]],
+        inference_config: InferenceConfig,
         verbose: bool = False,
         model_policy: Policy = None,
     ) -> None:
-        self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
+        assert inference_config, "Please provide inference_config."
+        assert tokenizer, "Please provide a tokenizer, either a defined one or str"
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
         self.dtype = inference_config.dtype
+        self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.generation_config = inference_config.to_generation_config(self.model_config)
         model = model.eval()
         model.to(self.dtype)
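
The constructor now calls inference_config.to_generation_config(self.model_config) to build the engine-level default GenerationConfig. That helper's body is not shown in this diff; the sketch below only illustrates the idea, and the InferenceConfig attribute names it reads (max_output_len, do_sample, beam_width) are assumptions rather than confirmed fields.

# Illustrative only: one way a to_generation_config() helper could derive a
# default transformers GenerationConfig. Attribute names are assumptions.
from transformers import GenerationConfig


def to_generation_config_sketch(inference_config, model_config) -> GenerationConfig:
    meta = {
        "max_new_tokens": getattr(inference_config, "max_output_len", 256),
        "do_sample": getattr(inference_config, "do_sample", False),
        "num_beams": getattr(inference_config, "beam_width", 1),
    }
    # Reuse special-token ids from the model config when they are defined.
    for attr in ("pad_token_id", "bos_token_id", "eos_token_id"):
        value = getattr(model_config, attr, None)
        if value is not None:
            meta[attr] = value
    return GenerationConfig(**meta)
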
@@ -80,6 +81,8 @@ class InferenceEngine:
         self.request_handler = RequestHandler(self.inference_config, self.model_config)
         self.k_cahce, self.v_cache = self.request_handler.get_kvcache()
         # DISCUSS maybe move this into batch info?
         self.counter = count()

     def _verify_config(self) -> None:
@@ -137,7 +140,7 @@ class InferenceEngine:
         self,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        generation_config: GenerationConfig = None,
+        generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
         """
         Executing the inference step.
@@ -158,6 +161,10 @@ class InferenceEngine:
         output_seqs_list = []
         output_tokens_list = []

+        # intuition: If user provide a generation config, we should replace the existing one.
+        if generation_config is not None:
+            self.generation_config = generation_config
+
         while self.request_handler.check_unfinished_seqs():
             output_seqs_list += self.step()
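
generate() now treats a caller-supplied GenerationConfig as an override of the engine-level default built in __init__. Continuing the usage sketch from above (the field values are arbitrary examples):

from transformers import GenerationConfig

# Default path: keep the GenerationConfig derived from inference_config.
outputs = engine.generate(prompts=["Introduce some landmarks in Beijing."])

# Override path: the provided config replaces self.generation_config and is
# used by the request handler when searching/sampling tokens.
custom_cfg = GenerationConfig(max_new_tokens=64, do_sample=True, top_p=0.9)
outputs = engine.generate(
    prompts=["Introduce some landmarks in Beijing."],
    generation_config=custom_cfg,
)
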
@@ -285,8 +292,8 @@ class InferenceEngine:
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()

         return finished_sequences

@@ -2,6 +2,7 @@ from typing import List
 import torch
 from transformers.configuration_utils import PretrainedConfig
+from transformers.generation import GenerationConfig

 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
@@ -94,6 +95,10 @@ class RequestHandler:
         head_dim = model_config.hidden_size // model_config.num_attention_heads

         fd_inter_tensor = FDIntermTensors()

+        if fd_inter_tensor._tensors_initialized:
+            fd_inter_tensor._reset()
+
         fd_inter_tensor.initialize(
             max_batch_size=self.max_batch_size,
             num_attn_heads=model_config.num_attention_heads,
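
The new guard resets FDIntermTensors before it is initialized again, which suggests the object holds shared state that can already be populated when a new RequestHandler is constructed (for example when several handlers are created in one process or across tests). The snippet below is a generic sketch of that initialize/reset guard pattern, not the actual FDIntermTensors implementation:

import torch


class SharedWorkspace:
    """Generic sketch of a shared tensor holder guarded by an initialized flag."""

    def __init__(self) -> None:
        self._tensors_initialized = False
        self.buffer = None

    def _reset(self) -> None:
        # Drop previously allocated tensors so initialize() starts clean.
        self.buffer = None
        self._tensors_initialized = False

    def initialize(self, max_batch_size: int, hidden_size: int) -> None:
        assert not self._tensors_initialized, "call _reset() before re-initializing"
        self.buffer = torch.zeros(max_batch_size, hidden_size)
        self._tensors_initialized = True


workspace = SharedWorkspace()
if workspace._tensors_initialized:
    workspace._reset()
workspace.initialize(max_batch_size=8, hidden_size=128)
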
@@ -170,6 +175,7 @@ class RequestHandler:
                             self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
                     for seq in remove_list:
                         lst.remove(seq)

         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
@@ -229,7 +235,7 @@ class RequestHandler:
         return None

-    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
+    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
         if generation_config.num_beams == 1:
             if generation_config.do_sample:
                 sample_tokens = multinomial_sample(generation_config, probs)
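
_sample dispatches on the GenerationConfig: with num_beams == 1 it samples from the probability distribution when do_sample is set and otherwise falls back to greedy selection. The snippet below is an illustrative stand-in for that dispatch using plain torch ops; the repository's multinomial_sample / greedy_sample helpers may differ in detail.

import torch
from transformers import GenerationConfig


def sample_next_tokens(probs: torch.Tensor, generation_config: GenerationConfig) -> torch.Tensor:
    """Illustrative dispatch; probs has shape [batch_size, vocab_size]."""
    if generation_config.num_beams == 1:
        if generation_config.do_sample:
            # Multinomial sampling: draw one token id per sequence from probs.
            return torch.multinomial(probs, num_samples=1).squeeze(1)
        # Greedy: pick the highest-probability token per sequence.
        return torch.argmax(probs, dim=-1)
    raise NotImplementedError("beam search is not covered by this sketch")


probs = torch.softmax(torch.randn(2, 32000), dim=-1)
greedy_ids = sample_next_tokens(probs, GenerationConfig(do_sample=False))
sampled_ids = sample_next_tokens(probs, GenerationConfig(do_sample=True))
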
@@ -240,7 +246,7 @@ class RequestHandler:
         return sample_tokens

-    def mark_finished(self, sequence: Sequence, generation_config):
+    def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
             sequence.output_token_id[-1] == generation_config.eos_id
             or sequence.output_len >= generation_config.max_output_len
@@ -250,7 +256,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()

-    def search_tokens(self, generation_config, logits):
+    def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.
         """