Merge pull request #2258 from hpcaitech/debug/ckpt-autoparallel

[autockpt] provide option for activation checkpoint search in SPMD solver
Boyuan Yao
2023-01-04 11:37:28 +08:00
committed by GitHub
16 changed files with 300 additions and 182 deletions

colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py

@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
 @operator_registry.register(BCAST_FUNC_OP)
-class BinaryElementwiseHandler(NodeHandler):
+class BinaryElementwiseHandler(MetaInfoNodeHandler):
     """
     An BinaryBcastOpHandler is a node handler which deals with operations which have two
     operands and broadcasting occurs such as torch.add.
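For readers unfamiliar with the broadcasting pattern the docstring refers to, a minimal illustration (shapes chosen arbitrarily; this snippet is not part of the diff):

import torch

# torch.add broadcasts two operands with compatible but unequal shapes,
# which is the BCAST_FUNC_OP case this handler is registered for.
a = torch.randn(4, 1)
b = torch.randn(3)
out = torch.add(a, b)
print(out.shape)   # torch.Size([4, 3])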

colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py

@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -138,8 +138,7 @@ class NodeHandler(ABC):
             return None
         if self.node.op == 'call_module':
-            submod = self.node.graph.owning_module.get_submodule(self.node.target)
-            target = type(submod)
+            target = self.node.graph.owning_module.get_submodule(self.node.target)
         elif self.node.op == 'call_function':
            target = self.node.target
         elif self.node.op == 'call_method':
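The hunk above makes the handler return the submodule instance rather than its class. A small torch.fx sketch of why the instance is more useful (the Sequential model here is purely illustrative, not from the PR):

import torch.nn as nn
from torch.fx import symbolic_trace

model = symbolic_trace(nn.Sequential(nn.Linear(8, 4)))
for node in model.graph.nodes:
    if node.op == 'call_module':
        submod = node.graph.owning_module.get_submodule(node.target)
        # The instance still carries the constructor state a cost model needs
        # (in_features, out_features, bias, ...); the class stays recoverable
        # via submod.__class__, so registry lookups by type keep working.
        print(type(submod).__name__, submod.in_features, submod.out_features)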
@@ -235,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
         return self.strategies_vector
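In plain terms: metainfo-based costs are only attached when the target has a registered meta profile; otherwise the costs already produced by the default cost model in register_strategy are left untouched. A self-contained sketch of that fallback (all names below are stand-ins, not the colossalai classes):

from dataclasses import dataclass

@dataclass
class Strategy:
    compute_cost: float = 0.0
    memory_cost: float = 0.0

class MetaRegistry:
    def __init__(self, patched):
        self._patched = set(patched)

    def has(self, target):
        return target in self._patched

def attach_metainfo_costs(strategies, target, registry, profile):
    # profile(strategy, target) plays the role of MetaInfo(strategy, target)
    if registry.has(type(target)) or registry.has(target):
        for s in strategies:
            s.compute_cost, s.memory_cost = profile(s, target)
    return strategies

registry = MetaRegistry({'flatten'})
strategies = attach_metainfo_costs([Strategy(), Strategy()], 'flatten', registry,
                                   profile=lambda s, t: (1.0, 2.0))
print([(s.compute_cost, s.memory_cost) for s in strategies])   # [(1.0, 2.0), (1.0, 2.0)]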
@@ -282,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
         return self.strategies_vector

colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py

@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ReshapeGenerator, StrategyGenerator
@@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.flatten)
 @operator_registry.register(torch.Tensor.unsqueeze)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
-class ReshapeHandler(NodeHandler):
+class ReshapeHandler(MetaInfoNodeHandler):
     """
     A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
     """

colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py

@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import StrategyGenerator, UnaryElementwiseGenerator
@@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
 @operator_registry.register(torch.nn.modules.dropout.Dropout)
 @operator_registry.register(torch.Tensor.contiguous)
 @operator_registry.register(torch.nn.functional.dropout)
-class UnaryElementwiseHandler(NodeHandler):
+class UnaryElementwiseHandler(MetaInfoNodeHandler):
     """
     A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
     """