Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-11 18:01:05 +00:00)
[refactor] remove gpc dependency in colotensor's _ops (#1189)
@@ -1,8 +1,6 @@
 import torch
 import itertools
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.gemini.chunk import TensorState, Chunk
@@ -12,6 +10,7 @@ from typing import Dict, Iterable, List, Optional
 from colossalai.logging import get_dist_logger
 from collections import OrderedDict
 from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from .reducer import Reducer
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -45,8 +44,8 @@ class ColoDDP(torch.nn.Module):
         >>> from colossalai.core import global_context as gpc
         >>> from colossalai.context import ParallelMode
         >>> model = torch.nn.Linear(20, 1)
-        >>> model = ColoDDP(model)
-        >>> // model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
+        >>> pg = ProcessGroup(tp_degree = world_size//2)
+        >>> model = ColoDDP(model, pg)
         >>> logits = model(x)
         >>> loss = criterion(logits, labels)
         >>> model.backward(loss)
@@ -55,13 +54,13 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
-        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
+        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
         cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
             If it's None, the default CPU data parallel group will be used. Defaults to None.
     """
 
     def __init__(self,
                  module: torch.nn.Module,
-                 process_group: Optional[dist.ProcessGroup] = None,
+                 process_group: ColoProcessGroup,
                  cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:
@@ -69,8 +68,9 @@ class ColoDDP(torch.nn.Module):
         super().__init__()
         self.module = module
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
-        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
-        self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
+        assert process_group
+
+        self.process_group = process_group.dp_process_group()
         self.dp_world_size = self.process_group.size()
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
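The ColoDDP hunks above swap the implicit gpc.get_group(ParallelMode.DATA) lookup for an explicit ColoProcessGroup argument, which the wrapper narrows to its data-parallel subgroup via dp_process_group(). Below is a minimal sketch of the new construction path, assuming an already launched colossalai distributed environment with CUDA available; the colossalai.nn.parallel import path and the toy input/loss are assumptions, not part of this diff.

import torch
import torch.distributed as dist

from colossalai.nn.parallel import ColoDDP   # import path is an assumption, not shown in this diff
from colossalai.tensor import ProcessGroup   # aliased as ColoProcessGroup inside the patched module

# Build the group explicitly instead of relying on gpc.get_group(ParallelMode.DATA).
world_size = dist.get_world_size()
pg = ProcessGroup(tp_degree=world_size // 2)   # mirrors the docstring example above

model = torch.nn.Linear(20, 1).cuda()
model = ColoDDP(model, pg)    # the group is now a required argument (see `assert process_group`)
# Internally the wrapper keeps pg.dp_process_group() and its size as dp_world_size.

logits = model(torch.rand(4, 20).cuda())   # toy input, for illustration only
loss = logits.sum()
model.backward(loss)                       # ColoDDP exposes backward(loss), as in the docstring example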
@@ -120,6 +120,8 @@ class ColoDDP(torch.nn.Module):
             return empty_grad
 
         else:
+            #TODO(jiaruifang) fixme
+            raise NotImplementedError
             dist.all_reduce(grad, group=self.cpu_process_group)
             return grad
 
@@ -191,8 +193,11 @@ class ZeroDDP(ColoDDP):
     For more details, see the API reference of ``GeminiManager``.
     """
 
-    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module.half())
+    def __init__(self,
+                 module: torch.nn.Module,
+                 gemini_manager: GeminiManager,
+                 process_group: Optional[ColoProcessGroup] = None) -> None:
+        super().__init__(module.half(), process_group=process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
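The ZeroDDP hunk threads the same group through to ColoDDP via super().__init__(module.half(), process_group=process_group). A minimal sketch of one way a caller could supply it, assuming an already constructed GeminiManager; the helper function and both import paths below are assumptions, not part of ColossalAI.

import torch

from colossalai.gemini import GeminiManager    # import path is an assumption
from colossalai.nn.parallel import ZeroDDP     # import path is an assumption
from colossalai.tensor import ProcessGroup


def wrap_with_zero_ddp(module: torch.nn.Module,
                       gemini_manager: GeminiManager,
                       tp_degree: int = 1) -> ZeroDDP:
    """Hypothetical helper: build an explicit ProcessGroup and hand it to ZeroDDP."""
    pg = ProcessGroup(tp_degree=tp_degree)   # explicit group instead of a gpc lookup
    # ZeroDDP now accepts process_group and forwards it to ColoDDP's constructor.
    return ZeroDDP(module, gemini_manager, process_group=pg)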