[refactor] remove gpc dependency in colotensor's _ops (#1189)

This commit is contained in:
Jiarui Fang
2022-07-04 18:54:37 +08:00
committed by GitHub
parent abf6a262dc
commit 060b917daf
33 changed files with 499 additions and 357 deletions

View File

@@ -1,8 +1,6 @@
 import torch
 import itertools
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.gemini.chunk import TensorState, Chunk
@@ -12,6 +10,7 @@ from typing import Dict, Iterable, List, Optional
 from colossalai.logging import get_dist_logger
 from collections import OrderedDict
 from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from .reducer import Reducer
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -45,8 +44,8 @@ class ColoDDP(torch.nn.Module):
         >>> from colossalai.core import global_context as gpc
         >>> from colossalai.context import ParallelMode
         >>> model = torch.nn.Linear(20, 1)
-        >>> model = ColoDDP(model)
-        >>> // model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
+        >>> pg = ProcessGroup(tp_degree = world_size//2)
+        >>> model = ColoDDP(model, pg)
         >>> logits = model(x)
         >>> loss = criterion(logits, labels)
         >>> model.backward(loss)
@@ -55,13 +54,13 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
         process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
-        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
+        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
             If it's None, the default CPU data parallel group will be used. Defaults to None.
     """

     def __init__(self,
                  module: torch.nn.Module,
-                 process_group: Optional[dist.ProcessGroup] = None,
+                 process_group: ColoProcessGroup,
                  cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:
@@ -69,8 +68,9 @@ class ColoDDP(torch.nn.Module):
         super().__init__()
         self.module = module
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
-        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
-        self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
+        assert process_group
+
+        self.process_group = process_group.dp_process_group()
         self.dp_world_size = self.process_group.size()
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
@@ -120,6 +120,8 @@ class ColoDDP(torch.nn.Module):
                 return empty_grad
             else:
+                #TODO(jiaruifang) fixme
+                raise NotImplementedError
                 dist.all_reduce(grad, group=self.cpu_process_group)
                 return grad
@@ -191,8 +193,11 @@ class ZeroDDP(ColoDDP):
     For more details, see the API reference of ``GeminiManager``.
     """

-    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module.half())
+    def __init__(self,
+                 module: torch.nn.Module,
+                 gemini_manager: GeminiManager,
+                 process_group: Optional[ColoProcessGroup] = None) -> None:
+        super().__init__(module.half(), process_group=process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
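
For reference, a minimal usage sketch of the updated ColoDDP constructor, following the revised docstring example above. It assumes an already-launched distributed environment; `world_size`, `x`, `labels`, `criterion`, and the `colossalai.nn.parallel` import path are placeholders/assumptions rather than part of this diff:

>>> import torch
>>> from colossalai.tensor import ProcessGroup
>>> from colossalai.nn.parallel import ColoDDP    # import path assumed
>>> model = torch.nn.Linear(20, 1)
>>> # ProcessGroup bundles the communicators; ColoDDP now takes its data-parallel
>>> # group from pg.dp_process_group() instead of gpc.get_group(ParallelMode.DATA)
>>> pg = ProcessGroup(tp_degree=world_size // 2)
>>> model = ColoDDP(model, pg)
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)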

View File

@@ -52,5 +52,5 @@ class ColoModule(object):
     def get_param_names(self):
         return self._shard_params

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg):
         raise NotImplementedError

View File

@@ -1,5 +1,5 @@
 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
@@ -10,20 +10,18 @@ class ColoEmbedding(ColoModule):
         super(ColoEmbedding, self).__init__()
         self._register_shard_params(['weight'])

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
             if ComputePattern.TP1D == compute_pattern:
-                self._set_TP1D()
+                self._set_TP1D(pg)

-    def _set_TP1D(self):
+    def _set_TP1D(self, pg: ProcessGroup):
         # TP1D Row Linear
         _compute_pattern = ComputePattern.TP1D
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
             },
             mode='row',
         )
@@ -32,9 +30,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
             },
             mode='col',
         )
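
A short sketch of how the updated registration hook is driven: the caller supplies the ProcessGroup, and the allowed TP1D patterns are built from its tensor-parallel world size. The direct ColoEmbedding() instantiation and the tp_degree value below are illustrative assumptions:

>>> from colossalai.tensor import ComputePattern, ProcessGroup
>>> pg = ProcessGroup(tp_degree=2)                 # illustrative degree
>>> colo_module = ColoEmbedding()                  # normally looked up via get_colo_module(...)
>>> colo_module.register(ComputePattern.TP1D, pg)  # patterns now derive from pg, not gpc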

View File

@@ -1,7 +1,5 @@
 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup


 class ColoLinear(ColoModule):
@@ -10,22 +8,19 @@ class ColoLinear(ColoModule):
         super(ColoLinear, self).__init__()
         self._register_shard_params(['weight', 'bias'])

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
             if ComputePattern.TP1D == compute_pattern:
-                self._set_TP1D()
+                self._set_TP1D(pg)

-    def _set_TP1D(self):
+    def _set_TP1D(self, pg):
         # TP1D Row Linear
         _compute_pattern = ComputePattern.TP1D
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias':
-                    None
+                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
+                'bias': None
             },
             mode='row',
         )
@@ -34,12 +29,8 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias':
-                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
-                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
+                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
+                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
             },
             mode='col',
         )
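
To make the rewritten specs concrete, a hedged sketch of what the two registered Linear patterns express, using distspec.shard exactly as in the diff (process group, shard dims, partitions per dim); the ProcessGroup construction is illustrative:

>>> from colossalai.tensor import distspec, ProcessGroup
>>> pg = ProcessGroup(tp_degree=4)                               # illustrative 4-way tensor parallelism
>>> # 'row' mode: weight split along its last dim across the tp ranks, bias kept whole
>>> row_weight = distspec.shard(pg, [-1], [pg.tp_world_size()])
>>> # 'col' mode: weight and bias both split along dim 0 across the tp ranks
>>> col_weight = distspec.shard(pg, [0], [pg.tp_world_size()])
>>> col_bias = distspec.shard(pg, [0], [pg.tp_world_size()])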

View File

@@ -1,5 +1,5 @@
 from typing import Dict
-from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
+from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
 from . import ColoModule
 import torch
@@ -29,7 +29,7 @@ def get_colo_module(module: torch.nn.Module):
         return None


-def check_colo_module(module: torch.nn.Module, recursive=True):
+def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
     if is_colo_module(module):
         colo_module = get_colo_module(module)
         param_names = colo_module.get_param_names()
@@ -50,7 +50,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                 continue
         if compute_pattern is not None:
-            colo_module.register(compute_pattern)
+            colo_module.register(compute_pattern, pg)
             if not colo_module.has_compute_pattern(compute_pattern):
                 raise Exception(
                     f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
@@ -76,16 +76,20 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
     if recursive == True:
         for submodule in module.children():
-            check_colo_module(submodule, recursive=True)
+            check_colo_module(submodule, pg=pg, recursive=True)


-def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
+def init_colo_module(module: torch.nn.Module,
+                     compute_spec: ComputeSpec,
+                     pg: ProcessGroup,
+                     recursive=True,
+                     mode='default'):
     compute_pattern = compute_spec.compute_pattern
     if is_colo_module(module):
         # for each param
         # set DistSpec and ComputeSpec
         colo_module = get_colo_module(module)
-        colo_module.register(compute_pattern)
+        colo_module.register(compute_pattern, pg)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
             raise NotImplementedError
         # a set for modules which update at least one param in the init process.
@@ -101,7 +105,7 @@ def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursi
                 for mod in param.shared_param_modules:
                     modules_update_param.add(mod)
         for mod in modules_update_param:
-            check_colo_module(mod, recursive=False)
+            check_colo_module(mod, pg, recursive=False)
     if recursive == True:
         for submodule in module.children():
-            init_colo_module(submodule, compute_spec, recursive=True, mode=mode)
+            init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)
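
Finally, a hedged sketch of the updated entry point: callers now pass the ProcessGroup explicitly instead of relying on gpc's PARALLEL_1D group. The ComputeSpec construction and the 'row' mode choice are assumptions for illustration, and `model` and `world_size` are placeholders:

>>> from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup
>>> pg = ProcessGroup(tp_degree=world_size)        # illustrative: every rank in one tp group
>>> spec = ComputeSpec(ComputePattern.TP1D)        # constructor assumed to take a ComputePattern
>>> init_colo_module(model, spec, pg=pg, recursive=True, mode='row')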