mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[zero] add chunk init function for users (#1729)
* add chunk manager init function * fix unit tests * add comment * add flush=True
This commit is contained in:
@@ -1,21 +1,23 @@
|
||||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from functools import partial
|
||||
from colossalai.nn.parallel import ColoDDP, ZeroDDP
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Callable, Type
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
import random
|
||||
from functools import partial
|
||||
from typing import Callable, Type
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.nn.parallel import ColoDDP, ZeroDDP
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
@@ -33,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
||||
|
||||
|
||||
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
|
||||
chunk_config = search_chunk_configuration(module, 4, 1024)
|
||||
chunk_config, _ = search_chunk_configuration(module, 4, 1024)
|
||||
chunk_manager = ChunkManager(chunk_config)
|
||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
||||
return ZeroDDP(module, gemini_manager)
|
||||
|
Reference in New Issue
Block a user