[zero] add chunk init function for users (#1729)

* add chunk manager init function

* fix unit tests

* add comment

* add flush=True
This commit is contained in:
HELSON
2022-10-18 16:31:22 +08:00
committed by GitHub
parent 2e1dbfb463
commit f69f9bf223
10 changed files with 691 additions and 629 deletions

View File

@@ -1,21 +1,23 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable, Type
import torch.distributed as dist
import os
import random
from functools import partial
from typing import Callable, Type
import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
def set_seed(seed):
@@ -33,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
chunk_config = search_chunk_configuration(module, 4, 1024)
chunk_config, _ = search_chunk_configuration(module, 4, 1024)
chunk_manager = ChunkManager(chunk_config)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ZeroDDP(module, gemini_manager)