Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-23 10:30:03 +00:00)
[refactor] remove gpc dependency in colotensor's _ops (#1189)
@@ -52,5 +52,5 @@ class ColoModule(object):
     def get_param_names(self):
         return self._shard_params

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg):
         raise NotImplementedError
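Taken together with the hunks below, this base-class change establishes a simple contract: every concrete ColoModule now receives the ProcessGroup from its caller instead of reading it from the global context (gpc). A minimal sketch of a hypothetical subclass satisfying the new signature (not part of this commit; the relative import assumes the same package as the files below):

from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec, ProcessGroup


class MyShardedModule(ColoModule):
    """Hypothetical ColoModule illustrating the new register(compute_pattern, pg) contract."""

    def __init__(self):
        super().__init__()
        self._register_shard_params(['weight'])

    def register(self, compute_pattern, pg: ProcessGroup):
        # the ProcessGroup is forwarded by the caller rather than looked up via gpc
        if compute_pattern not in self._allowed_patterns:
            if compute_pattern == ComputePattern.TP1D:
                self._register_allowed_patterns(
                    compute_pattern=ComputePattern.TP1D,
                    dist_specs={'weight': distspec.shard(pg, [0], [pg.tp_world_size()])},
                    mode='row',
                )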
@@ -1,5 +1,5 @@
 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode

@@ -10,20 +10,18 @@ class ColoEmbedding(ColoModule):
         super(ColoEmbedding, self).__init__()
         self._register_shard_params(['weight'])

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
             if ComputePattern.TP1D == compute_pattern:
-                self._set_TP1D()
+                self._set_TP1D(pg)

-    def _set_TP1D(self):
+    def _set_TP1D(self, pg: ProcessGroup):
         # TP1D Row Linear
         _compute_pattern = ComputePattern.TP1D
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
             },
             mode='row',
         )
@@ -32,9 +30,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
             },
             mode='col',
         )
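For reference, a short sketch of what the ColoEmbedding specs above now boil down to once a ProcessGroup is available; the tp_degree keyword used to build the group is an assumption, not shown in this diff:

from colossalai.tensor import ProcessGroup, distspec

# equivalent of what _set_TP1D(pg) registers for the embedding weight
pg = ProcessGroup(tp_degree=2)                               # assumed 2-way tensor-parallel group
row_spec = distspec.shard(pg, [0], [pg.tp_world_size()])     # 'row' mode: shard the vocabulary dim
col_spec = distspec.shard(pg, [-1], [pg.tp_world_size()])    # 'col' mode: shard the embedding dim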
@@ -1,7 +1,5 @@
 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup


 class ColoLinear(ColoModule):
@@ -10,22 +8,19 @@ class ColoLinear(ColoModule):
         super(ColoLinear, self).__init__()
         self._register_shard_params(['weight', 'bias'])

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
             if ComputePattern.TP1D == compute_pattern:
-                self._set_TP1D()
+                self._set_TP1D(pg)

-    def _set_TP1D(self):
+    def _set_TP1D(self, pg):
         # TP1D Row Linear
         _compute_pattern = ComputePattern.TP1D
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias':
-                    None
+                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
+                'bias': None
             },
             mode='row',
         )
@@ -34,12 +29,8 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
+                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
+                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
             },
             mode='col',
         )
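The two modes registered for ColoLinear above follow the usual 1D tensor-parallel layout: 'row' mode splits the weight along its input dimension (dim -1) and leaves the bias whole, while 'col' mode splits both weight and bias along the output dimension (dim 0). A small illustration of the resulting shard shapes, assuming 4-way tensor parallelism:

import torch

tp = 4                                              # assumed tensor-parallel degree
weight = torch.empty(16, 8)                         # torch.nn.Linear stores (out_features, in_features)
bias = torch.empty(16)

row_weight = torch.chunk(weight, tp, dim=-1)[0]     # mode='row': (16, 2) per rank, bias replicated
col_weight = torch.chunk(weight, tp, dim=0)[0]      # mode='col': (4, 8) per rank
col_bias = torch.chunk(bias, tp, dim=0)[0]          # mode='col': (4,) per rank
print(row_weight.shape, col_weight.shape, col_bias.shape)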
@@ -1,5 +1,5 @@
 from typing import Dict
-from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
+from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
 from . import ColoModule
 import torch

@@ -29,7 +29,7 @@ def get_colo_module(module: torch.nn.Module):
     return None


-def check_colo_module(module: torch.nn.Module, recursive=True):
+def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
     if is_colo_module(module):
         colo_module = get_colo_module(module)
         param_names = colo_module.get_param_names()
@@ -50,7 +50,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                 continue

             if compute_pattern is not None:
-                colo_module.register(compute_pattern)
+                colo_module.register(compute_pattern, pg)
                 if not colo_module.has_compute_pattern(compute_pattern):
                     raise Exception(
                         f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
@@ -76,16 +76,20 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                 raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
     if recursive == True:
         for submodule in module.children():
-            check_colo_module(submodule, recursive=True)
+            check_colo_module(submodule, pg=pg, recursive=True)


-def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
+def init_colo_module(module: torch.nn.Module,
+                     compute_spec: ComputeSpec,
+                     pg: ProcessGroup,
+                     recursive=True,
+                     mode='default'):
     compute_pattern = compute_spec.compute_pattern
     if is_colo_module(module):
         # for each param
         # set DistSpec and ComputeSpec
         colo_module = get_colo_module(module)
-        colo_module.register(compute_pattern)
+        colo_module.register(compute_pattern, pg)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
             raise NotImplementedError
         # a set for modules which update at least one param in the init process.
@@ -101,7 +105,7 @@ def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursi
             for mod in param.shared_param_modules:
                 modules_update_param.add(mod)
         for mod in modules_update_param:
-            check_colo_module(mod, recursive=False)
+            check_colo_module(mod, pg, recursive=False)
     if recursive == True:
         for submodule in module.children():
-            init_colo_module(submodule, compute_spec, recursive=True, mode=mode)
+            init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)
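Putting the pieces together, a hedged sketch of the caller-side flow after this refactor: the ProcessGroup is created once and threaded through init_colo_module, which forwards it to register() and check_colo_module(). The import path and the ProcessGroup construction below are assumptions, not part of this diff, and the module's parameters must already be ColoParameters (e.g. built under ColoInitContext) for the specs to take effect:

import torch
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup
from colossalai.nn.parallel.layers import init_colo_module   # assumed import path

# assumed setup: a 4-way tensor-parallel group and a model whose parameters
# are already ColoParameters
pg = ProcessGroup(tp_degree=4)                               # assumed constructor keyword
model = torch.nn.Sequential(torch.nn.Embedding(1000, 64), torch.nn.Linear(64, 64))

spec = ComputeSpec(ComputePattern.TP1D)
init_colo_module(model, spec, pg=pg, recursive=True, mode='row')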