[fix] merge conflicts

This commit is contained in:
Runyu Lu
2024-03-25 14:48:28 +08:00
15 changed files with 544 additions and 132 deletions

View File

@@ -88,7 +88,7 @@ class InferenceConfig:
use_cuda_kernel (bool): Whether to use CUDA kernels, which are faster but may occasionally lose some precision.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""
# NOTE: arrange configs according to their importance and frequency of usage
@@ -122,6 +122,7 @@ class InferenceConfig:
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False
# cuda kernel option
use_cuda_kernel: bool = False
@@ -149,6 +150,10 @@ class InferenceConfig:
self.dtype in _ALLOWED_DTYPES
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
# high-precision casting is unnecessary when the data type is already float32, so disable it
if self.dtype == torch.float32:
self.high_precision = False
# check distributed
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
self.tp_size * self.pp_size == dist.get_world_size()