mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[fix] merge conflicts
This commit is contained in:
@@ -88,7 +88,7 @@ class InferenceConfig:
|
||||
use_cuda_kernel (bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
|
||||
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use CUDA graph and eager execution in hybrid.
|
||||
max_context_len_to_capture (int): Maximum context length, per sequence, that can be captured by CUDA Graph.
|
||||
|
||||
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
|
||||
"""
|
||||
|
||||
# NOTE: arrange configs according to their importance and frequency of usage
|
||||
@@ -122,6 +122,7 @@ class InferenceConfig:
|
||||
pp_size: int = 1
|
||||
micro_batch_size: int = 1
|
||||
micro_batch_buffer_size: int = None
|
||||
high_precision: Optional[bool] = False
|
||||
|
||||
# cuda kernel option
|
||||
use_cuda_kernel: bool = False
|
||||
@@ -149,6 +150,10 @@ class InferenceConfig:
|
||||
self.dtype in _ALLOWED_DTYPES
|
||||
), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
|
||||
|
||||
# no casting is needed when the data type is already float32, so turn high_precision off
|
||||
if self.dtype == torch.float32:
|
||||
self.high_precision = False
|
||||
|
||||
# check distributed
|
||||
assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
|
||||
self.tp_size * self.pp_size == dist.get_world_size()
|
||||
|
Reference in New Issue
Block a user