Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[Inference] User Experience: update the logic of default tokenizer and generation config. (#5337)
* add
* fix
* fix
* pause
* fix
* fix pytest
* align
* fix
* license
* fix
* fix
* fix readme
* fix some bugs
* remove tokenizer config
@@ -2,6 +2,7 @@ from typing import List
 
 import torch
 from transformers.configuration_utils import PretrainedConfig
+from transformers.generation import GenerationConfig
 
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
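The newly imported GenerationConfig is what the type hints in the hunks below refer to. As a hedged illustration of the commit's theme (falling back to a sensible default generation config when the caller supplies none), a minimal sketch could look like this; the helper name and the field choices are assumptions, not the actual ColossalAI API:

    from transformers.generation import GenerationConfig

    def default_generation_config(inference_config) -> GenerationConfig:
        # Hypothetical fallback: derive a greedy-decoding config from the
        # engine-level InferenceConfig when the caller passes none.
        return GenerationConfig(
            num_beams=1,         # single beam, matching the _sample branch below
            do_sample=False,     # greedy by default
            eos_id=2,            # assumed custom attribute; see mark_finished below
            max_output_len=256,  # assumed custom attribute; see mark_finished below
        )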
@@ -94,6 +95,10 @@ class RequestHandler:
         head_dim = model_config.hidden_size // model_config.num_attention_heads
 
         fd_inter_tensor = FDIntermTensors()
+
+        if fd_inter_tensor._tensors_initialized:
+            fd_inter_tensor._reset()
+
         fd_inter_tensor.initialize(
             max_batch_size=self.max_batch_size,
             num_attn_heads=model_config.num_attention_heads,
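The new guard only makes sense if FDIntermTensors is shared state rather than a plain per-call object: a freshly constructed instance could never already have its tensors initialized. A minimal sketch of that pattern, purely illustrative and not the ColossalAI implementation, assuming a singleton-style holder of flash-decoding intermediate buffers:

    import torch

    class SharedIntermTensors:
        _instance = None
        _tensors_initialized = False

        def __new__(cls):
            # Singleton: every constructor call returns the same holder, so a
            # second RequestHandler (e.g. across pytest cases) sees old buffers.
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

        def initialize(self, max_batch_size: int, num_attn_heads: int) -> None:
            self.mid_output = torch.empty(max_batch_size, num_attn_heads)  # placeholder shape
            self._tensors_initialized = True

        def _reset(self) -> None:
            # Drop stale buffers so initialize() starts from a clean state.
            self.mid_output = None
            self._tensors_initialized = False

Resetting before initialize() avoids re-allocating on top of stale tensors when the handler is constructed more than once in the same process.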
@@ -170,6 +175,7 @@ class RequestHandler:
                             self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
                     for seq in remove_list:
                         lst.remove(seq)
+
         if self.running_list.ready_for_prefill():
             for seq in self.running_list.prefill:
                 seq.mark_running()
@@ -229,7 +235,7 @@ class RequestHandler:
 
         return None
 
-    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config):
+    def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig):
         if generation_config.num_beams == 1:
             if generation_config.do_sample:
                 sample_tokens = multinomial_sample(generation_config, probs)
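For the single-beam, do_sample=True branch above, a multinomial_sample helper can be as small as the following; this is a hedged sketch of the idea, not the ColossalAI implementation:

    import torch
    from transformers.generation import GenerationConfig

    def multinomial_sample(generation_config: GenerationConfig, probs: torch.Tensor) -> torch.Tensor:
        # probs: [batch_size, vocab_size] probabilities (post-softmax, and
        # post top-k/top-p filtering if any). Draw one token id per sequence.
        return torch.multinomial(probs, num_samples=1).squeeze(1)

With greedy decoding (do_sample=False) the natural counterpart is probs.argmax(dim=-1).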
@@ -240,7 +246,7 @@ class RequestHandler:
 
         return sample_tokens
 
-    def mark_finished(self, sequence: Sequence, generation_config):
+    def mark_finished(self, sequence: Sequence, generation_config: GenerationConfig):
         if (
             sequence.output_token_id[-1] == generation_config.eos_id
             or sequence.output_len >= generation_config.max_output_len
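Note that eos_id and max_output_len are not standard Hugging Face GenerationConfig fields; the code evidently relies on GenerationConfig keeping unknown keyword arguments as plain attributes. A hedged sketch of the stop check this hunk annotates, with illustrative values:

    from transformers.generation import GenerationConfig

    # Extra kwargs are stored as attributes on the config object.
    gen_cfg = GenerationConfig(eos_id=2, max_output_len=256)

    def is_finished(output_token_ids: list, generation_config: GenerationConfig) -> bool:
        # A sequence stops once it emits EOS or hits the output-length cap.
        return bool(output_token_ids) and (
            output_token_ids[-1] == generation_config.eos_id
            or len(output_token_ids) >= generation_config.max_output_len
        )

    assert not is_finished([5, 7], gen_cfg)
    assert is_finished([5, 7, 2], gen_cfg)  # ends with the assumed EOS id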
@@ -250,7 +256,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
-    def search_tokens(self, generation_config, logits):
+    def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.
         """
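Based only on the fragments above, search_tokens is the entry point that turns logits into token ids via _sample. A hedged sketch of that dispatch (the softmax/log-softmax step is an assumption, not confirmed by the diff):

    import torch

    def search_tokens_sketch(handler, generation_config, logits):
        # Convert raw logits to (log-)probabilities, then delegate to the
        # now explicitly typed _sample for greedy/multinomial selection.
        probs = torch.softmax(logits, dim=-1)
        logprobs = torch.log_softmax(logits, dim=-1)
        return handler._sample(probs, logprobs, generation_config)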