[zero] add chunk init function for users (#1729)

* add chunk manager init function

* fix unit tests

* add comment

* add flush=True
HELSON 2022-10-18 16:31:22 +08:00 committed by GitHub
parent 2e1dbfb463
commit f69f9bf223
10 changed files with 691 additions and 629 deletions

View File

@@ -1,3 +1,4 @@
from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from .manager import ChunkManager
from .search_utils import clasify_params, search_chunk_configuration
from .utils import init_chunk_manager

View File

@@ -1,100 +1,108 @@
import math
from typing import Dict, List, Tuple

import numpy as np
import torch.nn as nn

from colossalai.tensor import ColoParameter


def in_ddp(param: nn.Parameter) -> bool:
    return not getattr(param, '_ddp_to_ignore', False)


def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
    """Filter those parameters whose size is too large from others.
    """
    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
    params_size_arr = np.array(params_size)

    std = np.std(params_size_arr)
    mean = np.mean(params_size_arr)
    upper_limit = mean + 3 * std

    for key in size_dict:
        org_list = size_dict[key]
        size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))


def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
    """Get unused byte for a certain chunk size.
    """
    acc = 0
    left = 0
    for s in size_list:
        if s > left:
            acc += left
            left = chunk_size
        left -= s
    return left + acc


def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
    """Clasify each parameter by its size of DP group.
    """
    params_dict: Dict[int, List[ColoParameter]] = dict()
    for param in model.parameters():
        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
        if not in_ddp(param):
            continue

        param_key = param.process_group.dp_world_size()

        if param_key not in params_dict:
            params_dict[param_key] = []
        params_dict[param_key].append(param)

    return params_dict


def search_chunk_configuration(
        model: nn.Module,
        search_range_mb: float,
        search_interval_byte: int,    # hidden size is the best value for the interval
        min_chunk_size_mb: float = 32,
        filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
    search_range_byte = round(search_range_mb * 1024**2)
    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
    assert search_range_byte >= 0

    params_dict = clasify_params(model)
    config_dict: Dict[int, Dict] = dict()

    size_dict: Dict[int, List[int]] = dict()
    for key in params_dict:
        params_list = params_dict[key]
        size_list = [p.numel() for p in params_list]
        # let small parameters keep gathered in CUDA all the time
        total_size = sum(size_list)
        if total_size < min_chunk_size_byte:
            config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
        else:
            size_dict[key] = size_list

    if filter_exlarge_params:
        _filter_exlarge_params(model, size_dict)

    max_size = min_chunk_size_byte
    for key in size_dict:
        max_size = max(max_size, max(size_dict[key]))
    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)

    min_chunk_waste = float('+inf')
    best_chunk_size = start_size

    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
        temp_waste = 0
        for key in size_dict:
            temp_waste += _get_unused_byte(size_dict[key], chunk_size)
        if temp_waste < min_chunk_waste:
            min_chunk_waste = temp_waste
            best_chunk_size = chunk_size

    for key in params_dict:
        if key in config_dict:
            continue
        config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)

    return config_dict, min_chunk_waste
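
For reference, a minimal usage sketch of the updated search entry point (not part of the commit). It assumes a model built inside ColoInitContext so every parameter is a ColoParameter; the returned numbers in the comment are illustrative only.

# Hypothetical usage sketch: run the search and inspect the result.
from colossalai.gemini.chunk import search_chunk_configuration

config_dict, wasted_bytes = search_chunk_configuration(
    model,                        # nn.Module created under ColoInitContext (assumed to exist)
    search_range_mb=32,           # how far above the start size to search
    search_interval_byte=1024,    # ideally the model's hidden size
    min_chunk_size_mb=32,
    filter_exlarge_params=True)

# config_dict maps each DP world size to its chunk settings, for example:
# {2: {'chunk_size': 33554432, 'keep_gathered': False}}
# wasted_bytes is the total unused space (in bytes) under the best chunk size found.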

View File

@@ -0,0 +1,58 @@
from time import time
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration


def init_chunk_manager(model: nn.Module,
                       init_device: Optional[torch.device] = None,
                       hidden_dim: Optional[int] = None,
                       search_range_mb: Optional[float] = None,
                       min_chunk_size_mb: Optional[float] = None,
                       filter_exlarge_params: Optional[bool] = None) -> ChunkManager:

    kwargs_dict = dict()

    if hidden_dim:
        search_interval_byte = hidden_dim
    else:
        search_interval_byte = 1024    # 1kb
    kwargs_dict["search_interval_byte"] = search_interval_byte

    if search_range_mb:
        kwargs_dict["search_range_mb"] = search_range_mb

    if min_chunk_size_mb:
        kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb

    if filter_exlarge_params:
        kwargs_dict["filter_exlarge_params"] = filter_exlarge_params

    params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
    total_size = sum(params_sizes) / 1024**2

    dist.barrier()
    begine = time()

    config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)

    dist.barrier()
    end = time()
    span_s = end - begine
    wasted_size /= 1024**2

    if dist.get_rank() == 0:
        print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
              "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
              "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
              sep='',
              flush=True)
    dist.barrier()

    chunk_manager = ChunkManager(config_dict, init_device)
    return chunk_manager
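
Since the helper above is re-exported from colossalai.gemini.chunk (see the __init__.py change), a user can wire up the whole ZeRO stack from it. A minimal end-to-end sketch, assuming colossalai.launch has already been called and the model was built inside ColoInitContext; the hidden size and placement policy below are illustrative:

# Hypothetical usage sketch (not part of the commit).
import torch

from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP

chunk_manager = init_chunk_manager(model,                           # assumed ColoInitContext model
                                   init_device=torch.device('cpu'),
                                   hidden_dim=1024,                  # pass the model's hidden size as the search interval
                                   search_range_mb=32)
gemini_manager = GeminiManager('auto', chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)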

View File

@@ -1,21 +1,23 @@
import os
import random
from functools import partial
from typing import Callable, Type

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def set_seed(seed):
@@ -33,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:


def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
    chunk_config, _ = search_chunk_configuration(module, 4, 1024)
    chunk_manager = ChunkManager(chunk_config)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    return ZeroDDP(module, gemini_manager)

View File

@@ -1,105 +1,104 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal


def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
    chunk_manager = model.chunk_manager
    param_list = [p for p in model.parameters()]
    chunk_list = chunk_manager.get_chunks(param_list)
    for chunk in chunk_list:
        chunk_manager.access_chunk(chunk)

    for (p0, p1) in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())


def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    optimizer.zero_grad()
    logits = model(input_ids, attn_mask)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    pg = ProcessGroup()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    model.eval()
    torch_model.eval()

    set_seed(pg.dp_local_rank())
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 0:
            break

        logits = model(input_ids, attn_mask)
        logits = logits.float()
        loss = criterion(logits, input_ids)
        model.backward(loss)

        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
            torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)

        check_grad(model, torch_model)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_gpt_fwd_bwd()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(1)

View File

@@ -1,118 +1,116 @@
from functools import partial
from time import time

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal


def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # key is 'module.model.PARAMETER', so we truncate it
        key = key[7:]
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)


def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    optimizer.zero_grad()
    logits = model(input_ids, attn_mask)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 2:
            break

        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
        # debug_print([0], zero_logits, torch_logits)

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_gpt_fwd_bwd()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(1)

View File

@@ -1,66 +1,65 @@
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_spec(model, pg: ProcessGroup):
    tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        if 'weight' in n and 'ln' not in n:
            p.set_process_group(pg)
            p.set_tensor_spec(*tensor_spec)


def exam_search_chunk_size():

    world_size = torch.distributed.get_world_size()
    pg_tp = ProcessGroup(tp_degree=world_size)

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # make sure torch_model and model has the same parameter values
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    init_1d_row_spec(model, pg_tp)
    config_dict, _ = search_chunk_configuration(model,
                                                search_range_mb=1,
                                                search_interval_byte=16,
                                                min_chunk_size_mb=0,
                                                filter_exlarge_params=True)

    for key in config_dict:
        chunk_size = config_dict[key]['chunk_size']
        if world_size == 1:
            assert chunk_size == 31616
        else:
            assert chunk_size == 1024


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_search_chunk_size()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_search(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_search(4)

View File

@@ -1,110 +1,108 @@
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model.train()

    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    set_seed(451)
    torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    torch_dict = torch_model.state_dict()
    model.load_state_dict(torch_dict, strict=False)
    zero_dict = model.state_dict(only_rank_0=False)

    for key, value in torch_dict.items():
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()
    exam_load_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_ddp(1)

View File

@@ -1,97 +1,95 @@
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    set_seed(451)
    torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters())
    optim = ZeroOptimizer(optimizer, model, initial_scale=32)    # initialize the link between chunk16 and chunk32

    set_seed(dist.get_rank() * 3 + 128)
    model.train()
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 0:
            break
        optim.zero_grad()
        logits = model(input_ids, attn_mask)
        logits = logits.float()
        loss = criterion(logits, input_ids)
        optim.backward(loss)
        optim.step()

    optim_state_dict = optim.state_dict()
    optim.load_state_dict(optim_state_dict)
    new_state = optim.state_dict()['state']
    org_state = optim_state_dict['state']

    for k, v in org_state.items():
        w = new_state[k]
        for n, m in v.items():
            if isinstance(m, torch.Tensor):
                o = w[n]
                if m.device != o.device:
                    o = o.to(m.device)
                assert torch.equal(m, o)
            else:
                assert m == w[n]


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_zero_optim_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_optim(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_optim(1)

View File

@@ -1,23 +1,24 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
@@ -88,7 +89,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
        tp_init_spec_func(model, pg)

    dp_world_size = pg.dp_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[dp_world_size]['chunk_size'] = 5000
    config_dict[dp_world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':