[hotfix] Support extra_kwargs in ShardConfig (#5031)

* [refactor]: replace inference args with extra_kwargs in ShardConfig
* modify shardconfig
* polish code
* fix policy bug in llama
* fix bug in auto policy
* remove setattr in ShardConfig
@@ -81,8 +81,6 @@ Following is a description of `ShardConfig`'s arguments:

 - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
-- `inference_only`: Whether to perform the forward pass only. Defaults to False.
-
 ### Write your own policy

 If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
@@ -185,7 +183,6 @@ class ShardConfig:

     # Some possible future config fields
     tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d']  # support different tensor parallel modes
-    inference_only: bool  # only inject inference-suitable sharding policy
     use_flash_attention: bool  # whether to use flash attention to speed up attention
 ```
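Taken together, the two documentation hunks change the public contract: `inference_only` is no longer a first-class `ShardConfig` argument. A minimal before/after sketch of a call site under the new convention (the import path and flag values are assumptions for illustration, not part of this diff):

```python
from colossalai.shardformer import ShardConfig  # assumed import path

# Before this commit: inference was a dedicated dataclass field.
# shard_config = ShardConfig(enable_flash_attention=True, inference_only=True)

# After this commit: inference switches travel in the extra_kwargs dict,
# so an experimental flag no longer requires a ShardConfig schema change.
shard_config = ShardConfig(
    enable_flash_attention=True,
    extra_kwargs={"inference_only": True},
)
```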
@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    if shard_config.inference_only:
+    inference_only = shard_config.extra_kwargs.get("inference_only", False)
+    if inference_only:
         policy_location = _INFER_POLICY_LIST.get(full_name, None)
     else:
         policy_location = _POLICY_LIST.get(full_name, None)
@@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
             f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, shard_config.inference_only)
+        policy = import_policy(policy_location, inference_only)
     return policy()
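The same lookup can be exercised in isolation. Below is a self-contained toy of the two-registry dispatch pattern `get_autopolicy` uses; the registry names and contents here are illustrative stand-ins, not ColossalAI's real `_POLICY_LIST`/`_INFER_POLICY_LIST` tables:

```python
from typing import Dict

# Illustrative stand-ins for _POLICY_LIST / _INFER_POLICY_LIST.
TRAIN_REGISTRY: Dict[str, str] = {"LlamaForCausalLM": "llama.LlamaPolicy"}
INFER_REGISTRY: Dict[str, str] = {"LlamaForCausalLM": "llama.LlamaInferPolicy"}

def lookup_policy(full_name: str, extra_kwargs: Dict[str, bool]) -> str:
    # Flags are read defensively: an absent key means "training", which is
    # why the diff uses extra_kwargs.get("inference_only", False) instead
    # of attribute access on ShardConfig.
    inference_only = extra_kwargs.get("inference_only", False)
    registry = INFER_REGISTRY if inference_only else TRAIN_REGISTRY
    location = registry.get(full_name, None)
    if location is None:
        raise NotImplementedError(f"No auto policy for {full_name}")
    return location

print(lookup_policy("LlamaForCausalLM", {"inference_only": True}))
# -> llama.LlamaInferPolicy
```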
@@ -1,5 +1,5 @@
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Dict, Optional

 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -24,7 +24,6 @@ class ShardConfig:
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
         enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
         enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
-        inference_only (bool): Whether to perform the forward pass only. Defaults to False.
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -33,10 +32,9 @@ class ShardConfig:
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
     enable_all_optimization: bool = False
-    inference_only: bool = False
-    inference_gptq: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
+    extra_kwargs: Dict[str, bool] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
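`inference_gptq` is dropped the same way `inference_only` is, so any consumer would now read it out of the dict. A hedged sketch of that reading pattern (the helper name is hypothetical; only `extra_kwargs` comes from this diff):

```python
def wants_gptq(shard_config) -> bool:  # hypothetical helper, not in the diff
    # Mirror the old field default: an absent flag means False.
    return shard_config.extra_kwargs.get("inference_gptq", False)
```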
@@ -77,4 +75,3 @@ class ShardConfig:
         Set default params for inference.
         """
         # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
-        pass