[gemini] update ddp strict mode (#2518)

* [zero] add strict ddp mode for chunk init

* [gemini] update gpt example
HELSON 2023-01-28 14:35:25 +08:00 committed by GitHub
parent 0af793836c
commit 707b11d4a0
16 changed files with 133 additions and 54 deletions
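
A minimal sketch of how the new mode is switched on from user code (the model, placement policy, and import paths below are illustrative assumptions based on this release, not part of the commit itself):

```python
import torch.nn as nn
from colossalai.nn.parallel import GeminiDDP                       # import path assumed for this release
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# build the model so that its parameters become ColoParameters
with ColoInitContext(device=get_current_device()):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))    # placeholder model

# strict_ddp_mode=True keeps every parameter replicated: the chunk search and
# chunk init below classify all parameters under the global world size instead
# of per-parameter process groups.
model = GeminiDDP(model,
                  device=get_current_device(),
                  placement_policy='cuda',        # placement policy name assumed
                  strict_ddp_mode=True,
                  search_range_mb=32,
                  min_chunk_size_mb=32)
```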

View File

@@ -2,6 +2,7 @@ import math
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
+import torch.distributed as dist
 import torch.nn as nn
 
 from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
@@ -13,8 +14,14 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) ->
     """
     Filter those parameters whose size is too large (more than 3x standard deviations) from others.
     """
-    params_size = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
-    params_size_arr = np.array(params_size)
+    agg_size_list = []
+    for key in size_dict:
+        agg_size_list.extend(size_dict[key])
+
+    if len(agg_size_list) == 0:
+        return
+
+    params_size_arr = np.array(agg_size_list)
 
     std = np.std(params_size_arr)
     mean = np.mean(params_size_arr)
@@ -38,7 +45,15 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
     return left + acc
 
 
-def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
+def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
+    if strict_ddp_flag:
+        return local_param.numel_global()
+    else:
+        return local_param.numel()
+
+
+def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
+                                 strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
     """classify_params_by_dp_degree
 
     Classify the parameters by their dp degree
@@ -56,7 +71,10 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int
         if is_ddp_ignored(param):
             continue
 
-        param_key = param.process_group.dp_world_size()
+        if strict_ddp_flag:
+            param_key = dist.get_world_size()
+        else:
+            param_key = param.process_group.dp_world_size()
 
         if param_key not in params_dict:
             params_dict[param_key] = []
@@ -71,14 +89,18 @@ def search_chunk_configuration(
         search_interval_byte: int,    # hidden size is the best value for the interval
         min_chunk_size_mb: float = 32,
         filter_exlarge_params: bool = True,
-        memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
+        strict_ddp_flag: bool = False,
+        memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
     """search_chunk_configuration
 
     Args:
         model (nn.Module): torch module
        search_range_mb (float): searching range in mega byte.
        search_interval_byte (int): searching interval in byte.
+        min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
         filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
+        strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
+            all parameters keep replicated in this mode.
 
     Returns:
         Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
@@ -96,17 +118,20 @@ def search_chunk_configuration(
     min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
     assert search_range_byte >= 0
 
-    params_dict = classify_params_by_dp_degree(param_order)
+    params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
     config_dict: Dict[int, Dict] = dict()
+    total_param_size = 0
 
     size_dict: Dict[int, List[int]] = dict()
     for dp_degree in params_dict:
         params_list = params_dict[dp_degree]
-        size_list = [p.numel() for p in params_list]
+        size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
+        group_acc_size = sum(size_list)
+        total_param_size += group_acc_size
+
         # let small parameters keep gathered in CUDA all the time
-        total_size = sum(size_list)
-        if total_size < min_chunk_size_byte:
-            config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
+        if group_acc_size < min_chunk_size_byte:
+            config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
         else:
             size_dict[dp_degree] = size_list
@@ -134,4 +159,4 @@ def search_chunk_configuration(
             continue
         config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)
 
-    return config_dict, min_chunk_waste
+    return config_dict, total_param_size, min_chunk_waste
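
For reference, a sketch of the updated call site: `search_chunk_configuration` now returns three values, and callers that only need the config can drop the rest with `*_`, as the updated tests in this commit do (variable names below are illustrative; `model` stands for any module whose parameters were created under ColoInitContext):

```python
config_dict, total_param_size, wasted_bytes = search_chunk_configuration(
    model,                       # a module initialized under ColoInitContext
    search_range_mb=32,
    search_interval_byte=1024,
    min_chunk_size_mb=32,
    filter_exlarge_params=True,
    strict_ddp_flag=True)        # classify every parameter under the global world size

# callers that only need the chunk config simply discard the extra values:
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
```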

View File

@@ -19,38 +19,24 @@ def safe_div(a, b):
 
 def init_chunk_manager(model: nn.Module,
                        init_device: Optional[torch.device] = None,
                        hidden_dim: Optional[int] = None,
-                       search_range_mb: Optional[float] = None,
-                       min_chunk_size_mb: Optional[float] = None,
-                       filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
-    kwargs_dict = dict()
-
+                       **kwargs) -> ChunkManager:
     if hidden_dim:
         search_interval_byte = hidden_dim
     else:
-        search_interval_byte = 1024    # 1kb
-    kwargs_dict["search_interval_byte"] = search_interval_byte
-
-    if search_range_mb:
-        kwargs_dict["search_range_mb"] = search_range_mb
-
-    if min_chunk_size_mb:
-        kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb
-
-    if filter_exlarge_params:
-        kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
-
-    params_sizes = [p.numel() for p in model.parameters() if not is_ddp_ignored(p)]
-    total_size = sum(params_sizes) / 1024**2
+        search_interval_byte = 1024    # defaults to 1kb
+    kwargs["search_interval_byte"] = search_interval_byte
 
     dist.barrier()
     begin = time()
-    config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
+    config_dict, total_size, wasted_size = search_chunk_configuration(model, **kwargs)
     dist.barrier()
     end = time()
     span_s = end - begin
-    wasted_size /= 1024**2
+    mb_size = 1024**2
+    total_size /= mb_size
+    wasted_size /= mb_size
 
     if dist.get_rank() == 0:
         print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),

View File

@@ -234,11 +234,14 @@ class ZeroDDP(ColoDDP):
             for p in module.parameters():
                 param_order.append(p)
 
+        ddp_pg = ColoProcessGroup()
         for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
-            if strict_ddp_mode and not p.is_replicate():
-                p.set_dist_spec(ReplicaSpec())
+            if strict_ddp_mode:
+                if not p.is_replicate():
+                    p.set_dist_spec(ReplicaSpec())
+                p.set_process_group(pg=ddp_pg)
 
             if is_ddp_ignored(p):
                 p.data = p.data.to(device=get_current_device(), dtype=torch.float16)

View File

@@ -20,7 +20,7 @@ class GeminiDDP(ZeroDDP):
                  strict_ddp_mode: bool = False,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: Optional[float] = None,
+                 min_chunk_size_mb: float = 32,
                  memstats: Optional[MemStats] = None) -> None:
         """
         A torch.Module warpper using ZeRO-DP and Genimi.
@@ -53,6 +53,7 @@ class GeminiDDP(ZeroDDP):
                                             init_device=device,
                                             hidden_dim=hidden_dim,
                                             search_range_mb=search_range_mb,
-                                            min_chunk_size_mb=min_chunk_size_mb)
+                                            min_chunk_size_mb=min_chunk_size_mb,
+                                            strict_ddp_flag=strict_ddp_mode)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)

View File

@@ -1,3 +1,4 @@
+import math
 from copy import copy
 from functools import lru_cache
 from typing import Callable, Optional, Set
@@ -303,6 +304,11 @@ class ColoTensor(torch.Tensor):
         else:
             return size_list[args[0]]
 
+    def numel_global(self):
+        """Returns the number of elements in the tensor when it's replicated.
+        """
+        return math.prod(self.size_global())
+
     # Some API for dist spec check
     def is_replicate(self):
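
A self-contained sketch of what `numel_global` computes: the product of the replicated (global) shape, so the result is independent of how the tensor is currently sharded. The shapes below are made up for illustration; with strict DDP enabled, chunk sizes are derived from the global count rather than the local shard.

```python
import math

global_shape = (1024, 4096)        # full, replicated shape -> what size_global() returns
local_shape = (1024, 4096 // 4)    # shape of one shard on a 4-rank group -> what size() returns

print(math.prod(global_shape))     # 4194304, the value numel_global() would report
print(math.prod(local_shape))      # 1048576, the value numel() reports on each rank
```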

View File

@@ -263,7 +263,7 @@ def main():
     if args.distplan == "colossalai":
         # all param must use the same process group.
         world_size = torch.distributed.get_world_size()
-        shard_pg = ProcessGroup(tp_degree=world_size)
+        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
         default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
 
         # build GPT model

View File

@@ -35,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
 
 
 def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
-    chunk_config, _ = search_chunk_configuration(module, 4, 1024)
+    chunk_config, *_ = search_chunk_configuration(module, 4, 1024)
     chunk_manager = ChunkManager(chunk_config)
     gemini_manager = GeminiManager('cuda', chunk_manager)
     return ZeroDDP(module, gemini_manager)

View File

@@ -58,7 +58,7 @@ def exam_gpt_fwd_bwd(placement_policy,
         torch_p.data.copy_(p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)

View File

@@ -62,7 +62,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     assert len(step_list) == 4
 
     world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)

View File

@@ -58,7 +58,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
         p.data.copy_(torch_p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':

View File

@@ -57,7 +57,7 @@ def exam_inference(placement_policy, model_name: str):
         p.data.copy_(torch_p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':

View File

@@ -63,7 +63,7 @@ def exam_model_step(placement_policy, model_name: str):
         p.data.copy_(torch_p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':

View File

@@ -6,7 +6,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 
 import colossalai
-from colossalai.gemini.chunk import search_chunk_configuration
+from colossalai.gemini.chunk import init_chunk_manager, search_chunk_configuration
 from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port, get_current_device
@@ -23,7 +23,6 @@ def init_1d_row_spec(model, pg: ProcessGroup):
 
 def exam_search_chunk_size():
     world_size = torch.distributed.get_world_size()
     pg_tp = ProcessGroup(tp_degree=world_size)
@@ -34,11 +33,11 @@ def exam_search_chunk_size():
     with ColoInitContext(device=get_current_device()):
         model = model_builder()
     init_1d_row_spec(model, pg_tp)
-    config_dict, _ = search_chunk_configuration(model,
+    config_dict, *_ = search_chunk_configuration(model,
                                                 search_range_mb=1,
                                                 search_interval_byte=16,
                                                 min_chunk_size_mb=0,
                                                 filter_exlarge_params=True)
 
     for key in config_dict:
         chunk_size = config_dict[key]['chunk_size']
@@ -48,9 +47,68 @@ def exam_search_chunk_size():
         assert chunk_size == 1024
 
 
+def exam_search_strict_ddp():
+    world_size = torch.distributed.get_world_size()
+    default_shard_pg = ProcessGroup(tp_degree=world_size)
+    default_shard_spec = ShardSpec([-1], [world_size])
+
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    # get the chunk configuration over replicated models
+    with ColoInitContext(device=get_current_device()):
+        ddp_model = model_builder()
+    re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
+                                                              search_range_mb=1,
+                                                              search_interval_byte=16,
+                                                              min_chunk_size_mb=0,
+                                                              filter_exlarge_params=True,
+                                                              strict_ddp_flag=False)
+    # get the chunk configuration over sharded ddp models
+    with ColoInitContext(device=get_current_device(),
+                         default_pg=default_shard_pg,
+                         default_dist_spec=default_shard_spec):
+        sharded_ddp_model = model_builder()
+    sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
+                                                              search_range_mb=1,
+                                                              search_interval_byte=16,
+                                                              min_chunk_size_mb=0,
+                                                              filter_exlarge_params=True,
+                                                              strict_ddp_flag=True)
+    assert re_dict == sh_dict
+    for key in re_dict:
+        assert re_dict[key] == sh_dict[key]
+
+    assert re_total == sh_total
+    assert re_wasted == sh_wasted
+
+
+def exam_chunk_manager():
+    world_size = torch.distributed.get_world_size()
+
+    default_shard_pg = ProcessGroup(tp_degree=world_size)
+    default_shard_spec = ShardSpec([-1], [world_size])
+
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    with ColoInitContext(device=get_current_device(),
+                         default_pg=default_shard_pg,
+                         default_dist_spec=default_shard_spec):
+        sharded_ddp_model = model_builder()
+    chunk_manager = init_chunk_manager(sharded_ddp_model,
+                                       get_current_device(),
+                                       hidden_dim=16,
+                                       search_range_mb=1,
+                                       min_chunk_size_mb=0,
+                                       filter_exlarge_params=True,
+                                       strict_ddp_flag=True)
+    config_dict = chunk_manager.dp_degree_chunk_size_dict
+    assert len(config_dict) == 1
+    assert config_dict[world_size] == 31616
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     exam_search_chunk_size()
+    exam_search_strict_ddp()
+    exam_chunk_manager()
 
 
 @pytest.mark.dist

View File

@@ -41,7 +41,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
         torch_p.data.copy_(p.data)
 
     world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered
     chunk_manager = ChunkManager(config_dict)
@@ -73,7 +73,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
         torch_model = model_builder()    # get a different model
 
     world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered

View File

@@ -33,7 +33,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
         torch_model = model_builder()    # get a different model
 
     world_size = torch.distributed.get_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered

View File

@@ -85,7 +85,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         tp_init_spec_func(model, pg)
 
     dp_world_size = pg.dp_world_size()
-    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[dp_world_size]['chunk_size'] = 5000
     config_dict[dp_world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':