mirror of https://github.com/hpcaitech/ColossalAI.git
[refactor] move process group from _DistSpec to ColoTensor. (#1203)
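
The change that runs through this diff is a single signature: `distspec.shard` (a `_DistSpec` constructor) no longer takes a `ProcessGroup` as its first argument, because the group now lives on the `ColoTensor` itself rather than on its distribution spec. A minimal before/after sketch of a call site, assuming a tensor-parallel `ProcessGroup` named `pg` has already been constructed:

    # Before #1203: the process group was baked into the dist spec.
    spec = distspec.shard(pg, [0], [pg.tp_world_size()])

    # After #1203: the spec describes only the layout (shard dims and
    # partition counts); the owning ColoTensor carries the process group.
    spec = distspec.shard([0], [pg.tp_world_size()])

The hunks below apply this to the registered patterns of ColoEmbedding and ColoLinear, then update the module-utility helpers that read and write the specs.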
@@ -21,7 +21,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
+                'weight': distspec.shard([0], [pg.tp_world_size()]),
             },
             mode='row',
         )
@@ -30,7 +30,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
+                'weight': distspec.shard([-1], [pg.tp_world_size()]),
             },
             mode='col',
         )
@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
+                'weight': distspec.shard([-1], [pg.tp_world_size()]),
                 'bias': None
             },
             mode='row',
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
-                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
+                'weight': distspec.shard([0], [pg.tp_world_size()]),
+                'bias': distspec.shard([0], [pg.tp_world_size()])
             },
             mode='col',
         )
@@ -1,5 +1,6 @@
 from typing import Dict
-from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
+from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup
+from colossalai.tensor import distspec
 from . import ColoModule
 import torch
@@ -39,7 +40,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
             if not isinstance(param, ColoParameter):
                 raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
             if param.has_compute_spec():
-                cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
+                cur_compute_pattern = param.compute_spec.compute_pattern
                 if compute_pattern is None:
                     compute_pattern = cur_compute_pattern
                 else:
@@ -62,7 +63,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True)
             for param_name, dist_spec in param_specs.items():
                 param = module.get_parameter(param_name)
                 if param.has_compute_spec():
-                    if dist_spec != param.tensor_spec.dist_spec:
+                    if dist_spec != param.dist_spec:
                         cur_match = False
                         break
                 else:
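
With the spec split onto the tensor, the `tensor_spec` indirection disappears from the read path as well: code that inspected `param.tensor_spec.compute_spec` or `param.tensor_spec.dist_spec` now reads the attributes directly off the `ColoParameter`. A small sketch of the new read path, assuming a hypothetical `param` that is a `ColoParameter`:

    # Old:  param.tensor_spec.compute_spec / param.tensor_spec.dist_spec
    # New:
    if param.has_compute_spec():
        pattern = param.compute_spec.compute_pattern   # the ComputePattern enum
        layout = param.dist_spec                       # the _DistSpec, minus the pg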
@@ -100,8 +101,8 @@ def init_colo_module(module: torch.nn.Module,
                 continue
             param = module.get_parameter(param_name)
             if isinstance(param, ColoParameter):
-                spec = TensorSpec(dist_spec, compute_spec)
-                param.set_tensor_spec(spec)
+                param.set_dist_spec(dist_spec)
+                param.compute_spec = compute_spec
                 for mod in param.shared_param_modules:
                     modules_update_param.add(mod)
         for mod in modules_update_param:
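
On the write path, the removed `TensorSpec(dist_spec, compute_spec)` bundle becomes two independent assignments. A minimal sketch following this hunk, assuming an existing `ColoParameter` `param` whose process group was fixed at construction, a tensor-parallel group `pg`, and `ComputePattern.TP1D` purely as an illustrative pattern:

    from colossalai.tensor import ComputeSpec, ComputePattern, distspec

    dist_spec = distspec.shard([0], [pg.tp_world_size()])   # layout only, no pg
    compute_spec = ComputeSpec(ComputePattern.TP1D)         # illustrative pattern

    # Removed: param.set_tensor_spec(TensorSpec(dist_spec, compute_spec))
    param.set_dist_spec(dist_spec)       # resharded against the tensor's own pg
    param.compute_spec = compute_spec    # compute hint is now a plain attribute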