mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 12:12:46 +00:00
[autoparallel] Patch meta information for nodes that will not be handled by SPMD solver (#2823)
* [autoparallel] non spmd meta information generator
* [autoparallel] patch meta information for non spmd nodes
This commit is contained in:
parent
c7764d3f22
commit
eae77c831d
@ -3,6 +3,7 @@ from .binary_elementwise_ops import *
|
|||||||
from .conv import *
|
from .conv import *
|
||||||
from .embedding import *
|
from .embedding import *
|
||||||
from .linear import *
|
from .linear import *
|
||||||
|
from .non_spmd import *
|
||||||
from .norm import *
|
from .norm import *
|
||||||
from .pooling import *
|
from .pooling import *
|
||||||
from .tensor import *
|
from .tensor import *
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
import operator
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
|
||||||
|
from colossalai.fx.profiler.memory_utils import activation_size
|
||||||
|
from colossalai.fx.profiler.opcount import flop_mapping
|
||||||
|
|
||||||
|
from ..registry import meta_register
|
||||||
|
|
||||||
|
__all__ = ["non_spmd_meta_info"]
|
||||||
|
|
||||||
|
|
||||||
|
@meta_register.register(torch.Size)
@meta_register.register(torch.Tensor.size)
@meta_register.register(torch.finfo)
@meta_register.register(operator.le)
def non_spmd_meta_info(
    *args, **kwargs
) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Non-SPMD node meta information generator.

    The targets registered above (``torch.Size``, ``torch.Tensor.size``,
    ``torch.finfo`` and ``operator.le``) will not be handled by the SPMD
    solver, so placeholder meta information is returned for them: zero
    compute cost, default-constructed memory costs and empty tensor lists.

    Args:
        *args: Accepted for signature compatibility; ignored.
        **kwargs: Accepted for signature compatibility; ignored.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
            compute cost, memory cost, forward inputs, forward buffers and
            forward outputs (the three lists are always empty).
    """
    # These nodes contribute no compute cost in any phase.
    compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

    # Default-constructed MemoryCost per phase; the original note describes
    # this as "all zero" -- NOTE(review): confirm against MemoryCost defaults.
    memory_cost = TrainCycleItem(fwd=MemoryCost(), bwd=MemoryCost(), total=MemoryCost())

    # Three distinct empty lists: nothing is recorded as forward input,
    # buffer or output for these nodes.
    fwd_in, fwd_buffer, fwd_out = [], [], []

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
|
Loading…
Reference in New Issue
Block a user