mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[autoparallel]integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level * unify the profiling method * polish * fix no codegen bug * fix pass bug * fix liveness test * polish
This commit is contained in:
131
colossalai/auto_parallel/meta_profiler/shard_metainfo.py
Normal file
131
colossalai/auto_parallel/meta_profiler/shard_metainfo.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import Callable, List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
|
||||
from .registry import meta_register
|
||||
|
||||
__all__ = ['ShardMetaInfo']
|
||||
|
||||
|
||||
class ShardMetaInfo:
|
||||
"""ShardMetaInfo class
|
||||
This class is used to store meta info based on sharding strategy and the given
|
||||
target function.
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) -> None:
|
||||
# compute cost of forward and backward computation
|
||||
self.compute_cost: TrainCycleItem
|
||||
|
||||
# compute memory cost of forward and backward phase
|
||||
self.memory_cost: TrainCycleItem
|
||||
|
||||
# list of input tensors
|
||||
self.fwd_in: List[torch.Tensor]
|
||||
|
||||
# list of buffer tensors
|
||||
self.fwd_buffer: List[torch.Tensor]
|
||||
|
||||
# list of output tensors
|
||||
self.fwd_out: List[torch.Tensor]
|
||||
|
||||
# sharding strategy
|
||||
self._strategy = strategy
|
||||
|
||||
# target function
|
||||
self._target = target
|
||||
|
||||
# compute shard_metainfo if possible
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_shard_metainfo()
|
||||
|
||||
@property
|
||||
def strategy(self) -> ShardingStrategy:
|
||||
return self._strategy
|
||||
|
||||
@property
|
||||
def target(self) -> Callable:
|
||||
return self._target
|
||||
|
||||
@strategy.setter
|
||||
def strategy(self, strategy: ShardingStrategy) -> None:
|
||||
self._strategy = strategy
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_shard_metainfo()
|
||||
|
||||
@target.setter
|
||||
def target(self, target: Callable) -> None:
|
||||
self._target = target
|
||||
if self._strategy is not None and self._target is not None:
|
||||
self.compute_shard_metainfo()
|
||||
|
||||
def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
|
||||
"""
|
||||
Compute sharded opdata based on the given data and sharding spec.
|
||||
"""
|
||||
|
||||
if isinstance(sharding_spec, ShardingSpec):
|
||||
op_data = OperationData(name=operation_data.name,
|
||||
data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
|
||||
type=operation_data.type,
|
||||
logical_shape=operation_data.logical_shape)
|
||||
elif isinstance(sharding_spec, (list, tuple)):
|
||||
data = operation_data.data
|
||||
assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
|
||||
assert len(data) == len(sharding_spec), f"Length of data and sharding spec should be the same."
|
||||
sharded_data = []
|
||||
for d, s in zip(data, sharding_spec):
|
||||
sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta"))
|
||||
op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type)
|
||||
else:
|
||||
raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.")
|
||||
|
||||
return op_data
|
||||
|
||||
def compute_shard_metainfo(self):
|
||||
"""
|
||||
Compute meta info based on sharding strategy and the given target function.
|
||||
"""
|
||||
assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
|
||||
f"Meta info for {self._target} is not registered."
|
||||
if meta_register.has(self._target.__class__):
|
||||
# module
|
||||
meta_func = meta_register.get(self._target.__class__)
|
||||
|
||||
# check whether the target in the list that we don't need to save activation
|
||||
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
|
||||
else:
|
||||
# function
|
||||
meta_func = meta_register.get(self._target)
|
||||
|
||||
# check whether the target in the list that we don't need to save activation
|
||||
save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
|
||||
|
||||
# construct args for meta_func
|
||||
args = [self.compute_sharded_opdata(k, v) for k, v in self._strategy.sharding_specs.items()]
|
||||
|
||||
# construct kwargs
|
||||
if self.target in INPLACE_MODULE:
|
||||
kwargs = {'inplace': self.target.inplace}
|
||||
elif self.target in INPLACE_OPS:
|
||||
kwargs = {'inplace': True}
|
||||
else:
|
||||
kwargs = {'inplace': False}
|
||||
|
||||
# compute metainfo with meta_func
|
||||
self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
|
||||
|
||||
# process corner case for NO_SAVE_ACTIVATION
|
||||
if not save_fwd_in:
|
||||
self.fwd_in = []
|
Reference in New Issue
Block a user