mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[fix] pytest and fix dyn grid bug
This commit is contained in:
@@ -10,6 +10,8 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.generation import GenerationConfig
|
||||
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
|
||||
GibiByte = 1024**3
|
||||
|
||||
logger = logging.Logger(__name__)
|
||||
@@ -45,13 +47,16 @@ class InputMetaData:
|
||||
|
||||
block_tables: torch.Tensor = None
|
||||
sequence_lengths: torch.Tensor = None
|
||||
fd_inter_tensor: torch.Tensor = None
|
||||
fd_inter_tensor: FDIntermTensors = None
|
||||
batch_size: int = 64 # current_batch_size
|
||||
is_prompts: bool = False
|
||||
use_cuda_graph: bool = False
|
||||
kv_seq_len: int = 512
|
||||
head_dim: int = 32
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"InputMetaData(block_tables={self.block_tables}, sequence_lengths={self.sequence_lengths}, fd_inter_tensor={self.fd_inter_tensor}, batch_size={self.batch_size}, is_prompts={self.is_prompts}, use_cuda_graph={self.use_cuda_graph}, kv_seq_len={self.kv_seq_len}, head_dim={self.head_dim})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceConfig:
|
||||
@@ -117,9 +122,10 @@ class InferenceConfig:
|
||||
|
||||
# cuda_graph
|
||||
use_cuda_graph: bool = False
|
||||
max_context_len_to_capture: int = max_input_len * max_output_len
|
||||
max_context_len_to_capture: int = 512
|
||||
|
||||
def __post_init__(self):
|
||||
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
|
||||
self._verify_config()
|
||||
|
||||
def _verify_config(self) -> None:
|
||||
|
Reference in New Issue
Block a user