[hotfix]fix argument naming in docs and examples (#4083)

This commit is contained in:
Baizhou Zhang
2023-06-26 23:50:04 +08:00
committed by GitHub
parent e89b127d8e
commit 4da324cd60
8 changed files with 40 additions and 41 deletions

View File

@@ -34,9 +34,9 @@ class ColossalAIStrategy(DDPStrategy):
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
@@ -61,9 +61,9 @@ class ColossalAIStrategy(DDPStrategy):
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_mb: int = 32, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_mb: float = 32, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
@@ -83,57 +83,51 @@ class ColossalAIStrategy(DDPStrategy):
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()'
)
warnings.warn(f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()')
if stage == 3 and precision == 'fp32':
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
precision = 'fp16'
self.precision = precision
self.shard_init = shard_init
optim_kwargs = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)
optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
# NOTE: dist should be initialized before calling get_current_device()
if stage == 3:
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision=precision,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_mb=search_range_mb,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb,
# zero_optim_config
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
**optim_kwargs
)
# optim_config
**optim_kwargs)
else:
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
# optim_config
**optim_kwargs
)
# optim_config
**optim_kwargs)
super().__init__(seed, plugin_initializer)