[refactor] remove gpc dependency in colotensor's _ops (#1189)

Jiarui Fang
2022-07-04 18:54:37 +08:00
committed by GitHub
parent abf6a262dc
commit 060b917daf
33 changed files with 499 additions and 357 deletions
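
This refactor replaces the implicit global parallel context (gpc) lookup in ColoTensor's _ops with a ProcessGroup object that callers construct and pass in explicitly. A minimal sketch of the new calling convention, assuming the ColoDDP signature shown in the diffs below (the wrap_with_ddp helper name is hypothetical):

import torch
from colossalai.tensor import ProcessGroup
from colossalai.nn.parallel import ColoDDP


def wrap_with_ddp(module: torch.nn.Module) -> ColoDDP:
    # A ProcessGroup built with no arguments is assumed to describe the
    # default (whole-world) data-parallel group, replacing the old gpc lookup.
    pg = ProcessGroup()
    return ColoDDP(module, process_group=pg)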


@@ -15,6 +15,7 @@ import torch.distributed as dist
 import os
 import random
 import numpy as np
+from colossalai.tensor import ProcessGroup
 
 def set_seed(seed):
@@ -27,14 +28,16 @@ def set_seed(seed):
 def init_ddp(module: torch.nn.Module) -> ColoDDP:
-    return ColoDDP(module)
+    pg = ProcessGroup()
+    return ColoDDP(module, process_group=pg)
 
 
 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ZeroDDP(module, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(module, gemini_manager, pg)
 
 
 class Net(torch.nn.Module):


@@ -13,6 +13,7 @@ from colossalai.nn.parallel import ZeroDDP, ColoDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Callable
 from collections import OrderedDict
+from colossalai.tensor import ProcessGroup
 
 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
@@ -22,14 +23,16 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic
 def init_ddp(module: torch.nn.Module) -> ColoDDP:
-    return ColoDDP(module)
+    pg = ProcessGroup()
+    return ColoDDP(module, process_group=pg)
 
 
 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ZeroDDP(module, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(module, gemini_manager, process_group=pg)
 
 
 def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
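
The ZeroDDP path follows the same convention: the caller builds the ProcessGroup and hands it to the wrapper rather than relying on gpc. A sketch under the same assumptions; the colossalai.gemini import path for ChunkManager and the wrap_with_zero name are not taken from this diff:

import torch
from colossalai.tensor import ProcessGroup
from colossalai.gemini import ChunkManager  # import path assumed, not shown in this diff
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP


def wrap_with_zero(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
    # Mirror init_ddpv2 above: search for a chunk size only when chunking is enabled.
    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    # The process group is now constructed by the caller and passed explicitly.
    pg = ProcessGroup()
    return ZeroDDP(module, gemini_manager, process_group=pg)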