mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] adapted llama to the new API (#4036)
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import List, Literal

from colossalai.cluster.dist_coordinator import DistCoordinator

__all__ = ['ShardConfig']

@@ -19,9 +20,19 @@ class ShardConfig:
        gather_output (bool): Whether to gather the output of the model of the last layer
    """
    tensor_parallel_size: int

    # TODO: add support for tensor parallel
    # pipeline_parallel_size: int
    # data_parallel_size: int
    tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
    inference_only: bool = True
    gather_output: bool = True
    # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
    # inference_only: bool = True
    # gather_output: bool = True

    def __post_init__(self):
        coordinator = DistCoordinator()

        # ensure the parallel size can match the world size
        world_size = coordinator.world_size
        self.data_parallel_size = world_size // self.tensor_parallel_size
        assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
            f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
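A minimal usage sketch of ShardConfig after this change (not part of the commit; the module path and launch method are assumptions). Because __post_init__ creates a DistCoordinator and reads the world size from it to derive data_parallel_size, the torch.distributed process group must already be initialized, e.g. under torchrun.

import torch.distributed as dist

from colossalai.shardformer.shard.shard_config import ShardConfig  # module path assumed

# e.g. launched on 8 GPUs: torchrun --nproc_per_node=8 example.py
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()

# tensor_parallel_size must divide the world size, otherwise the assert in
# __post_init__ fails
config = ShardConfig(tensor_parallel_size=2)

# data_parallel_size is derived in __post_init__ as world_size // tensor_parallel_size
assert config.data_parallel_size == world_size // 2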