mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-18 17:31:53 +00:00
parent 88759e289e
commit e761ad2cd7
@@ -1,109 +0,0 @@
-#include <cuda_runtime.h>
-#include <nccl.h>
-#include <torch/extension.h>
-
-#define CHECK_CUDA(x) \
-  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) \
-  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) \
-  CHECK_CUDA(x);       \
-  CHECK_CONTIGUOUS(x)
-
-#define CUDACHECK(cmd)                                              \
-  do {                                                              \
-    cudaError_t e = cmd;                                            \
-    if (e != cudaSuccess) {                                         \
-      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
-             cudaGetErrorString(e));                                \
-      exit(EXIT_FAILURE);                                           \
-    }                                                               \
-  } while (0)
-
-#define NCCLCHECK(cmd)                                              \
-  do {                                                              \
-    ncclResult_t r = cmd;                                           \
-    if (r != ncclSuccess) {                                         \
-      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
-             ncclGetErrorString(r));                                \
-      exit(EXIT_FAILURE);                                           \
-    }                                                               \
-  } while (0)
-
-class ZeroCommMgr {
- public:
-  cudaStream_t cuda_stream;
-  ncclComm_t nccl_comm;
-
-  ZeroCommMgr(const ncclComm_t &comm_) {
-    CUDACHECK(cudaStreamCreate(&cuda_stream));
-    nccl_comm = comm_;
-  }
-};
-
-ZeroCommMgr *GMGR = nullptr;
-
-#ifdef USE_C10D_NCCL
-#include <c10d/ProcessGroupNCCL.hpp>
-
-class HackNCCLGroup : public c10d::ProcessGroupNCCL {
- public:
-  ncclComm_t getcomm(at::Device dev) {
-    ncclUniqueId ncclID;
-    int rank = getRank();
-    if (rank == 0) {
-      ncclGetUniqueId(&ncclID);
-    }
-
-    broadcastUniqueNCCLID(&ncclID, c10d::OpType::SEND, "colossal_zero_comm",
-                          rank);
-
-    ncclComm_t comm;
-    NCCLCHECK(ncclCommInitRank(&comm, getSize(), ncclID, rank));
-    return comm;
-  }
-};
-
-int create_zero_comm(c10d::ProcessGroupNCCL &pg, at::Device dev) {
-  auto *hack_group = reinterpret_cast<HackNCCLGroup *>(&pg);
-  GMGR = new ZeroCommMgr(hack_group->getcomm(dev));
-  assert(GMGR->nccl_comm != 0);
-  return 0;
-}
-#endif
-
-template <typename scalar_t>
-void colo_all_gather_impl(scalar_t *recvbuff, int rank, int sendcount,
-                          ncclDataType_t data_type) {
-  scalar_t *sendbuff = recvbuff + (rank * sendcount);
-  NCCLCHECK(ncclAllGather(sendbuff, recvbuff, sendcount, data_type,
-                          GMGR->nccl_comm, GMGR->cuda_stream));
-  CUDACHECK(cudaStreamSynchronize(GMGR->cuda_stream));
-}
-
-int colo_all_gather(torch::Tensor &input_tensor, int rank, int world_size) {
-  CHECK_INPUT(input_tensor);
-
-  auto total_size = input_tensor.numel();
-  assert(total_size % world_size == 0);
-  auto sendcount = total_size / world_size;
-
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input_tensor.scalar_type(), "colo_all_gather", ([&] {
-        colo_all_gather_impl<scalar_t>(
-            input_tensor.data_ptr<scalar_t>(), rank, sendcount,
-            input_tensor.scalar_type() == at::ScalarType::Half ? ncclHalf
-                                                               : ncclFloat);
-      }));
-
-  return 0;
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-#ifdef USE_C10D_NCCL
-  m.def("create_comm", &create_zero_comm,
-        "Create the communication environment for Colossal Zero");
-#endif
-  m.def("inplace_all_gather", &colo_all_gather,
-        "All gather operation used in Colossal Zero");
-}
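The removed extension's central trick is visible in `colo_all_gather_impl`: the send buffer is just a view into the receive buffer at offset `rank * sendcount`, so NCCL gathers in place and no separate staging tensor is allocated. Below is a minimal single-process sketch of that buffer contract in plain PyTorch; `inplace_all_gather_sim` is a name introduced here for illustration, and the ranks are simulated with a loop, so no NCCL or CUDA is required.

import torch
from typing import List

def inplace_all_gather_sim(buffers: List[torch.Tensor], sendcount: int) -> None:
    """Simulate the in-place all-gather contract of the removed kernel.

    buffers[r] is rank r's full-size receive buffer; before the call, only
    the slice [r*sendcount : (r+1)*sendcount] holds rank r's shard.
    Afterwards every buffer holds all shards, in rank order.
    """
    world_size = len(buffers)
    for src in range(world_size):
        shard = buffers[src][src * sendcount:(src + 1) * sendcount].clone()
        for dst in range(world_size):
            buffers[dst][src * sendcount:(src + 1) * sendcount].copy_(shard)

world_size, sendcount = 4, 3
bufs = [torch.zeros(world_size * sendcount) for _ in range(world_size)]
for r in range(world_size):
    bufs[r][r * sendcount:(r + 1) * sendcount] = float(r + 1)  # rank r's shard
inplace_all_gather_sim(bufs, sendcount)
assert all(torch.equal(bufs[0], b) for b in bufs)  # every rank sees all shards

The final assertion mirrors what `ncclAllGather` guarantees: after the call, every rank's buffer contains all shards in rank order.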
@@ -1 +0,0 @@
-from .zero_comm import ZeroDist
@@ -1,46 +0,0 @@
-import torch
-import torch.distributed as dist
-from torch.distributed import ProcessGroup
-from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.utils import get_current_device
-from typing import Optional
-
-ZERO_USE_NCCL = False
-try:
-    import colossal_zero_comm
-    ZERO_USE_NCCL = True
-except ImportError:
-    print("Please reinstall Colossal-AI with pip to build the NCCL extension.")
-
-
-class ZeroCommWorld(metaclass=SingletonMeta):
-    """Zero communicator, used for communications in zero parallel."""
-
-    def __init__(self):
-        super().__init__()
-        self.zero_pg: Optional[ProcessGroup] = None
-
-    @property
-    def is_initialized(self):
-        return self.zero_pg is not None
-
-    def zero_comm_init(self, comm_group: ProcessGroup):
-        if not ZERO_USE_NCCL:
-            return
-
-        if self.is_initialized:
-            assert self.zero_pg == comm_group, "Cannot initialize the zero group twice"
-            return
-
-        self.zero_pg = comm_group
-        colossal_zero_comm.create_comm(self.zero_pg, get_current_device())
-
-    def zero_all_gather(self, input_tensor: torch.Tensor):
-        assert self.zero_pg is not None, "Please initialize the zero communication world first"
-        rank = dist.get_rank(self.zero_pg)
-        world_size = self.zero_pg.size()
-        colossal_zero_comm.inplace_all_gather(input_tensor, rank, world_size)
-
-
-ZeroDist = ZeroCommWorld()
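`zero_all_gather` expects the caller to pass one full-size tensor whose rank-th chunk already holds the local shard (exactly how `_gather_tensor` below uses it). If the extension is unavailable, roughly the same contract can be met with the stock `torch.distributed.all_gather` over chunk views of the same buffer. A hedged sketch, assuming the tensor's length divides evenly by the world size; `zero_all_gather_fallback` is a hypothetical name, not ColossalAI API:

import torch
import torch.distributed as dist

def zero_all_gather_fallback(input_tensor: torch.Tensor, group=None) -> None:
    # Hypothetical pure-torch.distributed stand-in for
    # colossal_zero_comm.inplace_all_gather: input_tensor is the full-size
    # buffer, with the local shard already placed at this rank's offset.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    assert input_tensor.numel() % world_size == 0
    # The chunks are views of the one buffer, so the gathered result lands
    # in place; the local shard is cloned to avoid input/output aliasing.
    chunks = list(torch.chunk(input_tensor, chunks=world_size, dim=0))
    dist.all_gather(chunks, chunks[rank].clone(), group=group)

The difference from the removed kernel is that the backend may still allocate internal staging for the list-based gather, which is precisely the extra buffer the in-place NCCL call avoided.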
@@ -12,7 +12,6 @@ from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.zero.comm import ZeroDist
 from contextlib import AbstractContextManager


@@ -192,7 +191,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")
-        ZeroDist.zero_comm_init(self.dp_process_group)    # initialize zero communication world

         # substitute fan-in and fan-out calculation
         self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
@@ -1,6 +1,5 @@
 from .base_shard_strategy import BaseShardStrategy
 from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
 from .tensor_shard_strategy import TensorShardStrategy
-from .zero_tensor_shard_strategy import ZeroTensorShardStrategy

-__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'ZeroTensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
@@ -1,38 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-from colossalai.utils import get_current_device
-from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
-from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.zero.comm import ZeroDist
-
-from .tensor_shard_strategy import TensorShardStrategy
-
-
-class ZeroTensorShardStrategy(TensorShardStrategy):
-    """Uses the same shard scheme as `TensorShardStrategy`,
-    but its all-gather operation is in-place, meaning that no extra buffer is created
-    (`torch.distributed.all_gather` creates an extra buffer).
-    This reduces peak memory usage in zero-offload.
-    Note that this strategy is tightly coupled with zero:
-    you cannot change its communication group, and you must use ZeroContext to create your model.
-    """
-
-    def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
-        if not t.is_sharded:
-            return
-        target_device = t.device
-        payload_numel = t.payload.numel()
-        world_size = dist.get_world_size(process_group)
-        rank = dist.get_rank(process_group)
-
-        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
-        buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
-        buffer_list[rank].copy_(t.payload)
-
-        ZeroDist.zero_all_gather(buffer)    # note: process_group is ignored here
-        gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
-        t.reset_payload(gathered_payload)
-        colo_model_data_tensor_move_inline(t, target_device)
-        t.is_sharded = False
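One subtle step above is the `torch.narrow(...).reshape(...)` at the end of `_gather_tensor`: shards are equal-sized, so when `origin_numel` does not divide evenly by the world size the gathered buffer carries padding that must be dropped before reshaping. A small worked example of just that step (shapes and values are illustrative only, not taken from the repo):

import torch

# Original tensor: 15 elements, world size 4 -> shards padded to 4 elements
# each, so the gathered buffer has 16 elements, one of them padding.
origin_shape = (3, 5)
origin_numel = 15
world_size = 4
shard_numel = -(-origin_numel // world_size)  # ceil division -> 4

buffer = torch.arange(shard_numel * world_size, dtype=torch.float32)  # 16 elements
gathered = torch.narrow(buffer, 0, 0, origin_numel).reshape(origin_shape)
assert gathered.shape == origin_shape  # padding dropped, original shape restored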
setup.py
@@ -134,12 +134,6 @@ if build_cuda_ext:
             'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)
         })

-    ext_modules.append(
-        cuda_ext_helper(name='colossal_zero_comm',
-                        sources=['zero_comm.cpp'],
-                        extra_cuda_flags=['-DUSE_C10D_NCCL'],
-                        extra_cxx_flags=['-DUSE_C10D_NCCL']))
-
     ext_modules.append(
         cuda_ext_helper('colossal_C', [
             'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
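`cuda_ext_helper` is a repo-local wrapper, so the exact flags it passes are not shown in this diff. For orientation only, a rough stand-alone equivalent of the removed registration using the stock `torch.utils.cpp_extension` API might look like the sketch below; the flags and layout are assumptions, not the project's actual build code.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Sketch: register the deleted zero_comm.cpp extension directly, without
# the repo's cuda_ext_helper wrapper (compile flags are assumptions).
setup(
    name='colossal_zero_comm',
    ext_modules=[
        CUDAExtension(
            name='colossal_zero_comm',
            sources=['zero_comm.cpp'],
            extra_compile_args={
                'cxx': ['-DUSE_C10D_NCCL'],
                'nvcc': ['-DUSE_C10D_NCCL'],
            },
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)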
@@ -9,7 +9,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import ZeroTensorShardStrategy
+from colossalai.zero.shard_utils import BucketTensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
@@ -20,7 +20,7 @@ from common import CONFIG


 @parameterize("cpu_offload", [True, False])
-@parameterize("shard_strategy_class", [ZeroTensorShardStrategy])
+@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
 @parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
 def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
     test_models = ['repeated_computed_layers']
@@ -15,14 +15,14 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
     colo_model_mem_usage
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG


 @parameterize("init_device_type", ['cpu', 'cuda'])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def run_model_test(init_device_type, shard_strategy_class):
     logger = get_dist_logger("test_zero_init")
@@ -8,7 +8,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.shard_utils import ZeroTensorShardStrategy
+from colossalai.zero.shard_utils import BucketTensorShardStrategy
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
@@ -35,7 +35,7 @@ def run_mem_collector_testing():
     fraction = (50 * 1024**2) / cuda_capacity
     # limit max memory to 50MB
     colo_set_process_memory_fraction(fraction)
-    shard_strategy = ZeroTensorShardStrategy()
+    shard_strategy = BucketTensorShardStrategy()
     with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
         model = MyTestModel()

@@ -10,7 +10,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, ZeroTensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd


 @parameterize("enable_autocast", [True])
-@parameterize("shard_strategy_class", [ZeroTensorShardStrategy, BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
 def run_model_test(enable_autocast, shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()
@@ -11,7 +11,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -19,7 +19,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG


-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def run_zero_state_dict(shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18']
     shard_strategy = shard_strategy_class()