Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 12:12:46 +00:00)
[shardformer] fix submodule replacement bug when enabling pp (#4544)

Commit 2c787d7f47 (parent ec18fc7340)
@@ -92,22 +92,21 @@ class ModelSharder(object):
             param_replacement (List[Callable]): The function list to get parameter shard information in policy
             method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
             sub_module_replacement (List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
+            include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
         """
-        # released layers are not shardable
-        can_replace_param_or_layer = include is None or module in include
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
            (module.__class__ == origin_cls):
             if attr_replacement is not None:
                 self._replace_attr(module, attr_replacement)

-            if param_replacement is not None and can_replace_param_or_layer:
+            if param_replacement is not None and (include is None or module in include):
                 self._replace_param(module, param_replacement)

             if method_replacement is not None:
                 self._replace_method(module, method_replacement)

-            if sub_module_replacement is not None and can_replace_param_or_layer:
-                self._replace_sub_module(module, sub_module_replacement)
+            if sub_module_replacement is not None:
+                self._replace_sub_module(module, sub_module_replacement, include)

         for name, child in module.named_children():
             self._recursive_replace_layer(child,
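
This hunk removes the module-level gate that previously guarded both parameter and sub-module replacement: parameter replacement keeps its check on the matched parent module inline, while sub-module replacement now always runs and the include set is passed down so filtering happens per child. The sketch below is a minimal, hypothetical illustration (the Block class and the contents of include are illustrative, not ColossalAI code) of why gating on the parent appears to drop replacements under pipeline parallelism: a stage can hold a child layer even though the parent container that matches the policy key is not itself in include.

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(4, 4)
        self.mlp = nn.Linear(4, 4)

block = Block()
include = {block.attn}    # suppose this pipeline stage only keeps block.attn

# Old gate: the parent `block` is not in `include`, so none of its
# sub-modules were ever considered for replacement on this stage.
old_gate = block in include          # False

# New gate: always descend, then filter each sub-module individually.
new_targets = [m for m in (block.attn, block.mlp) if m in include]
print(old_gate, new_targets)         # False [Linear(in_features=4, out_features=4, bias=True)]
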
@@ -154,18 +153,17 @@ class ModelSharder(object):
             bound_method = MethodType(new_method, module)
             setattr(module, method_name, bound_method)

-    def _replace_sub_module(
-        self,
-        org_layer: nn.Module,
-        sub_module_replacement: List[SubModuleReplacementDescription],
-    ) -> None:
+    def _replace_sub_module(self,
+                            org_layer: nn.Module,
+                            sub_module_replacement: List[SubModuleReplacementDescription],
+                            include: Optional[Set[nn.Module]] = None) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict

         Args:
             org_layer (torch.nn.Module): The origin layer object to shard
             sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
+            include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
         """
         for description in sub_module_replacement:
             suffix = description.suffix
@@ -174,9 +172,12 @@ class ModelSharder(object):

             assert target_module is not None, 'target_module should not be None'

-            # TODO: support different parallel mode
             native_sub_module = getattr_(org_layer, suffix, ignore=True)

+            # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
+            if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
+                continue
+
             assert not isinstance(native_sub_module, target_module), \
                 f"The module with suffix {suffix} has been replaced, please check the policy"

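
The new skip rule fires only when an include set was provided and the resolved sub-module both exists on this stage and is not among the kept modules. Below is a standalone approximation; resolve_suffix and should_skip are illustrative names, and resolve_suffix is only a rough stand-in for getattr_(org_layer, suffix, ignore=True), assuming that helper walks a dotted attribute path and returns None when the attribute is missing.

from functools import reduce
from typing import Optional, Set

import torch.nn as nn

def resolve_suffix(module: nn.Module, suffix: str) -> Optional[nn.Module]:
    # Walk a dotted attribute path, returning None if any segment is missing.
    try:
        return reduce(getattr, suffix.split("."), module)
    except AttributeError:
        return None

def should_skip(native_sub_module: Optional[nn.Module],
                include: Optional[Set[nn.Module]]) -> bool:
    # Same condition as the added lines above: skip only when filtering is
    # active, the sub-module is present, and it is not held by this stage.
    return (include is not None) and (native_sub_module is not None) \
        and (native_sub_module not in include)

With include=None (no pipeline parallelism), should_skip always returns False, so the single-device behavior is unchanged.
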
@@ -7,6 +7,7 @@ from utils import shared_tempdir
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import (
     check_state_dict_equal,
@@ -100,6 +101,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     booster.load_model(new_model, model_ckpt_path)
     check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)

+    Randomizer.reset_index()
     clear_layout_converter()

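
The test-side changes are uniform: each parameterized run now calls Randomizer.reset_index() once its checks complete, before clear_layout_converter() here and before torch.cuda.empty_cache() in the hunks below, presumably so the randomizer's shared index allocation starts fresh for the next configuration instead of drifting across runs. A hypothetical harness showing the intended placement (run_configs, run_one_config, and configs are illustrative names, not part of the test suite):

from colossalai.shardformer.layer.utils import Randomizer

def run_configs(configs, run_one_config):
    for cfg in configs:
        try:
            run_one_config(cfg)       # build the sharded model, run forward/backward and checks
        finally:
            Randomizer.reset_index()  # start the next configuration from a clean index
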
@@ -4,6 +4,7 @@ from torch import distributed as dist

 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -105,6 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grads
     check_all_grad_tensors(grads_to_check)

+    Randomizer.reset_index()
     torch.cuda.empty_cache()

@@ -4,6 +4,7 @@ from torch import distributed as dist

 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -97,6 +98,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grads
     check_all_grad_tensors(grads_to_check)

+    Randomizer.reset_index()
     torch.cuda.empty_cache()

@@ -6,6 +6,7 @@ from torch import distributed as dist

 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -107,6 +108,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grads
     check_all_grad_tensors(grads_to_check)

+    Randomizer.reset_index()
     torch.cuda.empty_cache()
