[gemini] fix argument naming during chunk configuration searching

This commit is contained in:
Baizhou Zhang
2023-06-25 13:34:15 +08:00
parent b463651f3e
commit 0bb0b481b4
17 changed files with 62 additions and 64 deletions

View File

@@ -114,9 +114,9 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
def search_chunk_configuration(
model: nn.Module,
search_range_mb: float,
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32,
search_range_m: float,
search_interval: int, # hidden size is the best value for the interval
min_chunk_size_m: float = 32,
filter_exlarge_params: bool = True,
strict_ddp_flag: bool = False,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
@@ -126,9 +126,9 @@ def search_chunk_configuration(
Args:
model (nn.Module): torch module
search_range_mb (float): searching range in mega byte.
search_interval_byte (int): searching interval in byte.
min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
search_range_m (float): searching range divided by 2^20.
search_interval (int): searching interval.
min_chunk_size_m (float, optional): the minimum size of a distributed chunk, divided by 2^20..
filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
all parameters keep replicated in this mode.
@@ -145,9 +145,9 @@ def search_chunk_configuration(
for p in model.parameters():
param_order.append(p)
search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
search_range = round(search_range_m * 1024**2)
min_chunk_size = round(min_chunk_size_m * 1024**2)
assert search_range >= 0
params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
size_lcm = np.lcm.reduce(list(params_dict.keys()))
@@ -162,7 +162,7 @@ def search_chunk_configuration(
total_param_size += group_acc_size
# let small parameters keep gathered in CUDA all the time
if group_acc_size < min_chunk_size_byte:
if group_acc_size < min_chunk_size:
config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
else:
size_dict[dp_degree] = size_list
@@ -170,15 +170,15 @@ def search_chunk_configuration(
if filter_exlarge_params:
_filter_exlarge_params(model, size_dict)
max_size = min_chunk_size_byte
max_size = min_chunk_size
for key in size_dict:
max_size = max(max_size, max(size_dict[key]))
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
start_size = int(math.ceil(max_size / search_interval) * search_interval)
min_chunk_waste = float('+inf')
best_chunk_size = start_size
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
temp_waste = 0
for key in size_dict:
temp_waste += _get_unused_byte(size_dict[key], chunk_size)

View File

@@ -23,10 +23,10 @@ def init_chunk_manager(model: nn.Module,
verbose: bool = False,
**kwargs) -> ChunkManager:
if hidden_dim:
search_interval_byte = hidden_dim
search_interval = hidden_dim
else:
search_interval_byte = 1024 # defaults to 1kb
kwargs["search_interval_byte"] = search_interval_byte
search_interval = 1024 # defaults to 1024
kwargs["search_interval"] = search_interval
dist.barrier()
begin = time()
@@ -36,13 +36,13 @@ def init_chunk_manager(model: nn.Module,
dist.barrier()
end = time()
span_s = end - begin
mb_size = 1024**2
total_size /= mb_size
wasted_size /= mb_size
mega_unit = 1024**2
total_size /= mega_unit
wasted_size /= mega_unit
if verbose and dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
"used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
sep='',
flush=True)