[refactor] move process group from _DistSpec to ColoTensor. (#1203)

Jiarui Fang
2022-07-06 16:15:16 +08:00
committed by GitHub
parent 5da87ce35d
commit ae7d3f4927
34 changed files with 452 additions and 367 deletions
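
Before this change every `_DistSpec` carried its own `ProcessGroup` handle; afterwards the tensor owns the group and a dist spec describes only the layout (shard dims and partition counts). A minimal sketch of the ownership move, using toy stand-in classes rather than the ColossalAI API:

```python
from dataclasses import dataclass
from typing import Any, List

# Toy stand-ins that only illustrate where the process-group handle
# lives before and after this commit; not the real ColossalAI classes.

@dataclass
class OldDistSpec:        # before: every spec carries the process group
    pg: Any
    dims: List[int]
    num_partitions: List[int]

@dataclass
class NewDistSpec:        # after: pure layout metadata, group-agnostic
    dims: List[int]
    num_partitions: List[int]

@dataclass
class ToyColoTensor:      # after: the tensor holds the group exactly once
    pg: Any
    dist_spec: NewDistSpec
```

Holding one group handle per tensor avoids redundant (and potentially inconsistent) copies in every spec, and lets the same layout spec be reused across process groups.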

@@ -21,7 +21,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
+                'weight': distspec.shard([0], [pg.tp_world_size()]),
             },
             mode='row',
         )
@@ -30,7 +30,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
+                'weight': distspec.shard([-1], [pg.tp_world_size()]),
             },
             mode='col',
         )
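
As the hunks above show, `distspec.shard` no longer takes the process group as its first argument; it receives only the shard dims and partition counts, while the group is still consulted for the tensor-parallel world size. A hedged sketch of the new call sites (constructing `ProcessGroup()` with defaults inside an already-launched distributed job is an assumption):

```python
from colossalai.tensor import ProcessGroup, distspec

pg = ProcessGroup()  # assumed default construction in an initialized job

row_spec = distspec.shard([0], [pg.tp_world_size()])    # shard embedding rows
col_spec = distspec.shard([-1], [pg.tp_world_size()])   # shard embedding columns
```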

@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
+                'weight': distspec.shard([-1], [pg.tp_world_size()]),
                 'bias': None
             },
             mode='row',
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
-                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
+                'weight': distspec.shard([0], [pg.tp_world_size()]),
+                'bias': distspec.shard([0], [pg.tp_world_size()])
             },
             mode='col',
         )
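
The two modes explain why `bias` is `None` in `'row'` mode but sharded in `'col'` mode: a row-parallel linear sums partial products across ranks, so the bias must be added exactly once, whereas a column-parallel linear gives each rank a disjoint output slice, so the bias shards with it. A single-process sketch using plain slicing in place of a real process group (all names here are illustrative):

```python
import torch

# Single-process illustration of the two registered modes; plain tensor
# slicing stands in for a tensor-parallel process group of world size 2.
torch.manual_seed(0)
x = torch.randn(4, 8)    # batch of activations
w = torch.randn(6, 8)    # Linear weight: (out_features, in_features)
b = torch.randn(6)

# 'row' mode: weight sharded along dim -1 (input dim). Each rank holds a
# partial product; summing (the all-reduce) recovers the full output, so
# the bias stays replicated and is added once -> 'bias': None.
partials = [x[:, :4] @ w[:, :4].t(), x[:, 4:] @ w[:, 4:].t()]
row_out = sum(partials) + b

# 'col' mode: weight and bias both sharded along dim 0 (output dim).
# Each rank produces a disjoint slice of the output, so the bias shards too.
col_out = torch.cat([x @ w[:3].t() + b[:3], x @ w[3:].t() + b[3:]], dim=1)

assert torch.allclose(row_out, x @ w.t() + b, atol=1e-5)
assert torch.allclose(col_out, x @ w.t() + b, atol=1e-5)
```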

@@ -1,5 +1,6 @@
 from typing import Dict
-from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
+from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup
 from colossalai.tensor import distspec
 from . import ColoModule
 import torch
@@ -39,7 +40,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
         if not isinstance(param, ColoParameter):
             raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
         if param.has_compute_spec():
-            cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
+            cur_compute_pattern = param.compute_spec.compute_pattern
             if compute_pattern is None:
                 compute_pattern = cur_compute_pattern
             else:
@@ -62,7 +63,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
         for param_name, dist_spec in param_specs.items():
             param = module.get_parameter(param_name)
             if param.has_compute_spec():
-                if dist_spec != param.tensor_spec.dist_spec:
+                if dist_spec != param.dist_spec:
                     cur_match = False
                     break
             else:
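
Both hunks above replace the nested `param.tensor_spec.*` path with direct attributes: with `TensorSpec` dropped from the imports, `ColoTensor` exposes `dist_spec` and `compute_spec` itself. A toy model of the flattened surface (not the real class):

```python
class FlattenedSpecTensor:
    """Toy model of the flattened spec access; not the real ColoTensor."""

    def __init__(self, dist_spec, compute_spec=None):
        self.dist_spec = dist_spec        # was: param.tensor_spec.dist_spec
        self.compute_spec = compute_spec  # was: param.tensor_spec.compute_spec

    def has_compute_spec(self) -> bool:
        return self.compute_spec is not None
```
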
@@ -100,8 +101,8 @@ def init_colo_module(module: torch.nn.Module,
                 continue
             param = module.get_parameter(param_name)
             if isinstance(param, ColoParameter):
-                spec = TensorSpec(dist_spec, compute_spec)
-                param.set_tensor_spec(spec)
+                param.set_dist_spec(dist_spec)
+                param.compute_spec = compute_spec
                 for mod in param.shared_param_modules:
                     modules_update_param.add(mod)
     for mod in modules_update_param:
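
Applying a spec is now a two-step assignment rather than building a `TensorSpec` wrapper. A hedged end-to-end sketch using the names from the hunk above (`ComputePattern.TP1D` and the `ComputeSpec` constructor are assumptions about the surrounding API; `param` is a `ColoParameter`, `pg` an initialized `ProcessGroup`):

```python
from colossalai.tensor import ComputePattern, ComputeSpec, distspec

dist_spec = distspec.shard([0], [pg.tp_world_size()])
compute_spec = ComputeSpec(ComputePattern.TP1D)   # assumed pattern enum

param.set_dist_spec(dist_spec)      # replaces set_tensor_spec(TensorSpec(...))
param.compute_spec = compute_spec   # compute spec is now a plain attribute
```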