mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-04 09:40:11 +00:00
[Tensor] add module check and bert test (#1031)
* add Embedding * Add bert test * polish * add check module test * polish * polish * polish * polish
This commit is contained in:
parent
7106bd671d
commit
6c5996a56e
@ -9,11 +9,11 @@ from .optim.colo_optimizer import ColoOptimizer
|
|||||||
from . import distspec
|
from . import distspec
|
||||||
from .dist_spec_mgr import DistSpecManager
|
from .dist_spec_mgr import DistSpecManager
|
||||||
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
|
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
|
||||||
from .modules import ColoLinear
|
from .modules import ColoLinear, ColoEmbedding
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
|
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
|
||||||
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
|
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
|
||||||
'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
|
'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
|
||||||
'ColoLinear'
|
'ColoLinear', 'ColoEmbedding'
|
||||||
]
|
]
|
||||||
|
@ -26,6 +26,13 @@ class ColoParameter(ColoTensor):
|
|||||||
self._type = TensorType.MODEL
|
self._type = TensorType.MODEL
|
||||||
self._graph_node = None
|
self._graph_node = None
|
||||||
|
|
||||||
|
# a list contains modules sharing this ColoParameter with others.
|
||||||
|
self._shared_param_modules = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shared_param_modules(self):
|
||||||
|
return self._shared_param_modules
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_torch_tensor(tensor: torch.Tensor,
|
def from_torch_tensor(tensor: torch.Tensor,
|
||||||
requires_grad: bool = True,
|
requires_grad: bool = True,
|
||||||
@ -36,3 +43,4 @@ class ColoParameter(ColoTensor):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'ColoParameter: {torch.Tensor.__repr__(self)}'
|
return f'ColoParameter: {torch.Tensor.__repr__(self)}'
|
||||||
|
|
||||||
|
@ -30,6 +30,12 @@ class _DistSpec:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
res = "\nDistSpec:\n\t"
|
||||||
|
for attr in dir(self):
|
||||||
|
if not attr.startswith('__'):
|
||||||
|
res += f'{attr}: {str(getattr(self, attr))}\n\t'
|
||||||
|
return res
|
||||||
|
|
||||||
def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
|
def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
|
||||||
# process_group=None means global process group
|
# process_group=None means global process group
|
||||||
|
@ -18,7 +18,6 @@ def get_colo_module(module: torch.nn.Module):
|
|||||||
global _COLOSSAL_MODULES
|
global _COLOSSAL_MODULES
|
||||||
if is_colo_module(module):
|
if is_colo_module(module):
|
||||||
colo_module = _COLOSSAL_MODULES[type(module)]
|
colo_module = _COLOSSAL_MODULES[type(module)]
|
||||||
colo_module.register()
|
|
||||||
return colo_module
|
return colo_module
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@ -43,6 +42,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if compute_pattern is not None:
|
if compute_pattern is not None:
|
||||||
|
colo_module.register(compute_pattern)
|
||||||
if not colo_module.has_compute_pattern(compute_pattern):
|
if not colo_module.has_compute_pattern(compute_pattern):
|
||||||
raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
|
raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
|
||||||
|
|
||||||
@ -65,28 +65,34 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
|||||||
break
|
break
|
||||||
if match_specs == False:
|
if match_specs == False:
|
||||||
raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
|
raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
|
||||||
|
|
||||||
if recursive == True:
|
if recursive == True:
|
||||||
for submodule in module.children():
|
for submodule in module.children():
|
||||||
check_colo_module(submodule, recursive=True)
|
check_colo_module(submodule, recursive=True)
|
||||||
|
|
||||||
def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, label='default'):
|
def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
|
||||||
compute_pattern = parallel_action.compute_pattern
|
compute_pattern = parallel_action.compute_pattern
|
||||||
if is_colo_module(module):
|
if is_colo_module(module):
|
||||||
# for each param
|
# for each param
|
||||||
# set DistSpec and ParallelAction
|
# set DistSpec and ParallelAction
|
||||||
colo_module = get_colo_module(module)
|
colo_module = get_colo_module(module)
|
||||||
if not colo_module.has_compute_pattern_with_label(compute_pattern, label=label):
|
colo_module.register(compute_pattern)
|
||||||
|
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
for param_name, dist_spec in colo_module.get_dist_specs_with_label(compute_pattern, label=label).items():
|
# a set for modules which update at least one param in the init process.
|
||||||
|
# these modules need to be checked whether all params still match one of the valid compute pattern.
|
||||||
|
modules_update_param = {module}
|
||||||
|
for param_name, dist_spec in colo_module.get_dist_specs_with_mode(compute_pattern, mode=mode).items():
|
||||||
if dist_spec is None:
|
if dist_spec is None:
|
||||||
continue
|
continue
|
||||||
param = module.get_parameter(param_name)
|
param = module.get_parameter(param_name)
|
||||||
if isinstance(param, ColoParameter):
|
if isinstance(param, ColoParameter):
|
||||||
spec = TensorSpec(dist_spec, parallel_action)
|
spec = TensorSpec(dist_spec, parallel_action)
|
||||||
param.set_spec(spec)
|
param.set_spec(spec)
|
||||||
check_colo_module(module, recursive=False)
|
for mod in param.shared_param_modules:
|
||||||
|
modules_update_param.add(mod)
|
||||||
|
for mod in modules_update_param:
|
||||||
|
check_colo_module(mod, recursive=False)
|
||||||
if recursive == True:
|
if recursive == True:
|
||||||
for submodule in module.children():
|
for submodule in module.children():
|
||||||
init_colo_module(submodule, parallel_action, recursive=True, label=label)
|
init_colo_module(submodule, parallel_action, recursive=True, mode=mode)
|
||||||
|
|
@ -1,2 +1,3 @@
|
|||||||
from .colo_module import ColoModule
|
from .colo_module import ColoModule
|
||||||
from .linear import ColoLinear
|
from .linear import ColoLinear
|
||||||
|
from .embedding import ColoEmbedding
|
@ -21,14 +21,14 @@ class ColoModule(object):
|
|||||||
def _register_shard_params(self, params: List[str]):
|
def _register_shard_params(self, params: List[str]):
|
||||||
self._shard_params = params
|
self._shard_params = params
|
||||||
|
|
||||||
def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], label='default'):
|
def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode='default'):
|
||||||
assert list(dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
|
assert list(dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
|
||||||
if not compute_pattern in self._allowed_patterns:
|
if not compute_pattern in self._allowed_patterns:
|
||||||
self._allowed_patterns[compute_pattern] = {}
|
self._allowed_patterns[compute_pattern] = {}
|
||||||
self._allowed_patterns[compute_pattern][label] = dist_specs
|
self._allowed_patterns[compute_pattern][mode] = dist_specs
|
||||||
|
|
||||||
def _set_default(self, compute_pattern: ComputePattern, target_label):
|
def _set_default(self, compute_pattern: ComputePattern, target_mode):
|
||||||
self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_label]
|
self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode]
|
||||||
|
|
||||||
def has_compute_pattern(self, compute_pattern: ComputePattern):
|
def has_compute_pattern(self, compute_pattern: ComputePattern):
|
||||||
return compute_pattern in self._allowed_patterns
|
return compute_pattern in self._allowed_patterns
|
||||||
@ -37,15 +37,15 @@ class ColoModule(object):
|
|||||||
assert self.has_compute_pattern(compute_pattern)
|
assert self.has_compute_pattern(compute_pattern)
|
||||||
return self._allowed_patterns[compute_pattern]
|
return self._allowed_patterns[compute_pattern]
|
||||||
|
|
||||||
def has_compute_pattern_with_label(self, compute_pattern: ComputePattern, label='default'):
|
def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'):
|
||||||
return compute_pattern in self._allowed_patterns and label in self._allowed_patterns[compute_pattern]
|
return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern]
|
||||||
|
|
||||||
def get_dist_specs_with_label(self, compute_pattern: ComputePattern, label='default'):
|
def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'):
|
||||||
assert self.has_compute_pattern_with_label(compute_pattern, label)
|
assert self.has_compute_pattern_with_mode(compute_pattern, mode)
|
||||||
return self._allowed_patterns[compute_pattern][label]
|
return self._allowed_patterns[compute_pattern][mode]
|
||||||
|
|
||||||
def get_param_names(self):
|
def get_param_names(self):
|
||||||
return self._shard_params
|
return self._shard_params
|
||||||
|
|
||||||
def register(self):
|
def register(self, compute_pattern):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
36
colossalai/tensor/modules/embedding.py
Normal file
36
colossalai/tensor/modules/embedding.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from .colo_module import ColoModule
|
||||||
|
from colossalai.tensor import ComputePattern, distspec
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.context.parallel_mode import ParallelMode
|
||||||
|
|
||||||
|
class ColoEmbedding(ColoModule):
|
||||||
|
def __init__(self):
|
||||||
|
super(ColoEmbedding, self).__init__()
|
||||||
|
self._register_shard_params(['weight'])
|
||||||
|
|
||||||
|
def register(self, compute_pattern):
|
||||||
|
if not compute_pattern in self._allowed_patterns:
|
||||||
|
if ComputePattern.TP1D == compute_pattern:
|
||||||
|
self._set_TP1D()
|
||||||
|
|
||||||
|
def _set_TP1D(self):
|
||||||
|
# TP1D Row Linear
|
||||||
|
_compute_pattern = ComputePattern.TP1D
|
||||||
|
self._register_allowed_patterns(
|
||||||
|
compute_pattern=_compute_pattern,
|
||||||
|
dist_specs={
|
||||||
|
'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||||
|
},
|
||||||
|
mode='row',
|
||||||
|
)
|
||||||
|
|
||||||
|
# TP1D Col Linear
|
||||||
|
self._register_allowed_patterns(
|
||||||
|
compute_pattern=_compute_pattern,
|
||||||
|
dist_specs={
|
||||||
|
'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||||
|
},
|
||||||
|
mode='col',
|
||||||
|
)
|
||||||
|
|
||||||
|
self._set_default(compute_pattern=_compute_pattern, target_mode='row')
|
@ -7,12 +7,11 @@ class ColoLinear(ColoModule):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(ColoLinear, self).__init__()
|
super(ColoLinear, self).__init__()
|
||||||
self._register_shard_params(['weight', 'bias'])
|
self._register_shard_params(['weight', 'bias'])
|
||||||
self._register = False
|
|
||||||
|
|
||||||
def register(self):
|
def register(self, compute_pattern):
|
||||||
if self._register == False:
|
if not compute_pattern in self._allowed_patterns:
|
||||||
self._set_TP1D()
|
if ComputePattern.TP1D == compute_pattern:
|
||||||
self._register = True
|
self._set_TP1D()
|
||||||
|
|
||||||
def _set_TP1D(self):
|
def _set_TP1D(self):
|
||||||
# TP1D Row Linear
|
# TP1D Row Linear
|
||||||
@ -23,7 +22,7 @@ class ColoLinear(ColoModule):
|
|||||||
'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||||
'bias': None
|
'bias': None
|
||||||
},
|
},
|
||||||
label='row',
|
mode='row',
|
||||||
)
|
)
|
||||||
|
|
||||||
# TP1D Col Linear
|
# TP1D Col Linear
|
||||||
@ -33,7 +32,7 @@ class ColoLinear(ColoModule):
|
|||||||
'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||||
'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
|
'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
|
||||||
},
|
},
|
||||||
label='col',
|
mode='col',
|
||||||
)
|
)
|
||||||
|
|
||||||
self._set_default(compute_pattern=_compute_pattern, target_label='row')
|
self._set_default(compute_pattern=_compute_pattern, target_mode='row')
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from .utils import InsertPostInitMethodToModuleSubClasses
|
from .utils import InsertPostInitMethodToModuleSubClasses
|
||||||
import torch
|
import torch
|
||||||
from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
|
from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
|
||||||
ColoLinear
|
ColoLinear, ColoEmbedding
|
||||||
import types
|
import types
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -137,7 +137,12 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||||||
torch.nn.Module.__setattr__ = _setattr_with_colotensor
|
torch.nn.Module.__setattr__ = _setattr_with_colotensor
|
||||||
torch.nn.Module.register_parameter = _register_parameter_with_colotensor
|
torch.nn.Module.register_parameter = _register_parameter_with_colotensor
|
||||||
torch.nn.Module.get_parameter = _get_parameter_with_colotensor
|
torch.nn.Module.get_parameter = _get_parameter_with_colotensor
|
||||||
|
|
||||||
|
self._register_colo_modules()
|
||||||
|
|
||||||
|
def _register_colo_modules(self):
|
||||||
register_colo_module(torch.nn.Linear, ColoLinear())
|
register_colo_module(torch.nn.Linear, ColoLinear())
|
||||||
|
register_colo_module(torch.nn.Embedding, ColoEmbedding())
|
||||||
|
|
||||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -179,5 +184,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||||||
replaced_tensors[param] = colo_param
|
replaced_tensors[param] = colo_param
|
||||||
delattr(submodule, param_name)
|
delattr(submodule, param_name)
|
||||||
setattr(submodule, param_name, colo_param)
|
setattr(submodule, param_name, colo_param)
|
||||||
|
colo_param.shared_param_modules.append(submodule)
|
||||||
|
|
||||||
ColoModulize(module)
|
ColoModulize(module)
|
||||||
|
@ -15,21 +15,21 @@ import torch.nn.functional as F
|
|||||||
from colossalai.testing import rerun_if_address_is_in_use
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, ColoLinear
|
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, check_colo_module
|
||||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
|
||||||
def run_simplenet_with_spec(label):
|
def run_model_with_spec(mode, model_name):
|
||||||
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
|
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
with ColoInitContext(device=get_current_device()):
|
with ColoInitContext(device=get_current_device()):
|
||||||
model = model_builder(checkpoint=True)
|
model = model_builder(checkpoint=False)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
model_seq = model_builder(checkpoint=True)
|
model_seq = model_builder(checkpoint=False)
|
||||||
model_seq = model_seq.cuda()
|
model_seq = model_seq.cuda()
|
||||||
|
|
||||||
# Make two models have the same init params
|
# Make two models have the same init params
|
||||||
@ -37,7 +37,19 @@ def run_simplenet_with_spec(label):
|
|||||||
p2.data.copy_(p1.data)
|
p2.data.copy_(p1.data)
|
||||||
|
|
||||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||||
init_colo_module(model, parallel_action, recursive=True, label=label)
|
# Not all layers in Bert can be mod by 4.
|
||||||
|
# e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
|
||||||
|
if 'bert' == model_name:
|
||||||
|
if 'col' == mode:
|
||||||
|
init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode=mode)
|
||||||
|
init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
|
||||||
|
init_colo_module(model.classifier, parallel_action, recursive=True, mode='row')
|
||||||
|
elif 'row' == mode:
|
||||||
|
init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode='col')
|
||||||
|
init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
|
||||||
|
init_colo_module(model.classifier, parallel_action, recursive=True, mode=mode)
|
||||||
|
elif 'simple_net' == model_name:
|
||||||
|
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||||
|
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
for i, (data, label) in enumerate(train_dataloader):
|
for i, (data, label) in enumerate(train_dataloader):
|
||||||
@ -91,14 +103,14 @@ def run_simplenet_with_spec(label):
|
|||||||
if i > 3:
|
if i > 3:
|
||||||
break
|
break
|
||||||
|
|
||||||
def run_linear_with_spec(label):
|
def run_linear_with_spec(mode):
|
||||||
with ColoInitContext(device=get_current_device()):
|
with ColoInitContext(device=get_current_device()):
|
||||||
model = torch.nn.Linear(4, 8)
|
model = torch.nn.Linear(4, 8)
|
||||||
|
|
||||||
model_handy = copy(model)
|
model_handy = copy(model)
|
||||||
|
|
||||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||||
init_colo_module(model, parallel_action, recursive=True, label=label)
|
init_colo_module(model, parallel_action, recursive=True, mode=mode)
|
||||||
|
|
||||||
x = torch.rand(2, 4).cuda()
|
x = torch.rand(2, 4).cuda()
|
||||||
out = model(x)
|
out = model(x)
|
||||||
@ -110,28 +122,79 @@ def run_linear_with_spec(label):
|
|||||||
assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
|
assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
|
||||||
assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
|
assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
|
||||||
|
|
||||||
|
def run_check_shared_param():
|
||||||
|
from transformers import BertForMaskedLM, BertConfig
|
||||||
|
hidden_dim = 8
|
||||||
|
num_head = 4
|
||||||
|
sequence_length = 12
|
||||||
|
num_layer = 2
|
||||||
|
vocab_size = 24
|
||||||
|
|
||||||
def run_dist(rank, world_size, port, func):
|
config = BertConfig(vocab_size=vocab_size,
|
||||||
|
hidden_size=hidden_dim,
|
||||||
|
intermediate_size=hidden_dim * 4,
|
||||||
|
num_attention_heads=num_head,
|
||||||
|
max_position_embeddings=sequence_length,
|
||||||
|
num_hidden_layers=num_layer,
|
||||||
|
hidden_dropout_prob=0.,
|
||||||
|
attention_probs_dropout_prob=0.)
|
||||||
|
with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
|
||||||
|
model = BertForMaskedLM(config)
|
||||||
|
|
||||||
|
model = model.cuda()
|
||||||
|
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||||
|
# model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
|
||||||
|
assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
|
||||||
|
# They are all Linear, so both row is allowed. This should pass check.
|
||||||
|
init_colo_module(model, parallel_action, recursive=True, mode='row')
|
||||||
|
# This should be detected by check because you can not set weight as row while set bias as col.
|
||||||
|
col_spec = TensorSpec(
|
||||||
|
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||||
|
ParallelAction(ComputePattern.TP1D))
|
||||||
|
model.cls.predictions.bias.set_spec(col_spec)
|
||||||
|
try:
|
||||||
|
check_colo_module(model.cls.predictions.decoder, recursive=False)
|
||||||
|
except Exception as e:
|
||||||
|
assert 'incorrectly sharded' in str(e)
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
func('col')
|
run_linear_with_spec('col')
|
||||||
func('row')
|
run_linear_with_spec('row')
|
||||||
func('default')
|
|
||||||
|
|
||||||
|
def run_dist_model(rank, world_size, port):
|
||||||
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||||
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
for model_name in ['simple_net', 'bert']:
|
||||||
|
run_model_with_spec('col', model_name)
|
||||||
|
run_model_with_spec('row', model_name)
|
||||||
|
|
||||||
|
def run_dist_check(rank, world_size, port):
|
||||||
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||||
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
run_check_shared_param()
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize('world_size', [1, 4])
|
@pytest.mark.parametrize('world_size', [1, 4])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_module_linear_1d(world_size):
|
def test_module_linear_1d(world_size):
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_linear_with_spec)
|
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize('world_size', [1, 4])
|
@pytest.mark.parametrize('world_size', [1, 4])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_module_simplenet(world_size):
|
def test_module_model(world_size):
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_simplenet_with_spec)
|
run_func = partial(run_dist_model, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [1, 2])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_module_check(world_size):
|
||||||
|
run_func = partial(run_dist_check, world_size=world_size, port=free_port())
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_module_simplenet(4)
|
test_module_check(2)
|
Loading…
Reference in New Issue
Block a user