[hotfix] Suport extra_kwargs in ShardConfig (#5031)

* [refactor]: replace inference args with extra_kwargs in ShardConfig

* modify shardconfig

* polish code

* fix policy bug in llama

* fix bug in auto policy

* remove setattr in ShardConfig
This commit is contained in:
Zhongkai Zhao
2023-11-10 10:49:50 +08:00
committed by GitHub
parent 576a2f7b10
commit 70885d707d
23 changed files with 98 additions and 77 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -24,7 +24,6 @@ class ShardConfig:
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
inference_only (bool): Whether only doing forward passing. Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -33,10 +32,9 @@ class ShardConfig:
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_all_optimization: bool = False
inference_only: bool = False
inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
extra_kwargs: Dict[str, bool] = field(default_factory=dict)
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -77,4 +75,3 @@ class ShardConfig:
Set default params for inference.
"""
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
pass