Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 12:47:21 +00:00)
[hotfix]fix argument naming in docs and examples (#4083)
@@ -34,9 +34,9 @@ class ColossalAIStrategy(DDPStrategy):
             If it is "cuda", they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
         pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
         force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
-        search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
+        search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
         hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
-        min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
+        min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
         gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
         reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
         overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
@@ -61,9 +61,9 @@ class ColossalAIStrategy(DDPStrategy):
                  placement_policy: str = 'cuda',
                  pin_memory: bool = True,  # only for stage 3
                  force_outputs_fp32: bool = False,  # only for stage 3
-                 search_range_mb: int = 32,  # only for stage 3
+                 search_range_m: int = 32,  # only for stage 3
                  hidden_dim: Optional[int] = None,  # only for stage 3
-                 min_chunk_size_mb: float = 32,  # only for stage 3
+                 min_chunk_size_m: float = 32,  # only for stage 3
                  gpu_margin_mem_ratio: float = 0.0,  # only for stage 3
                  reduce_bucket_size: int = 12 * 1024**2,  # only for stage 1&2
                  overlap_communication: bool = True,  # only for stage 1&2
@@ -83,57 +83,51 @@ class ColossalAIStrategy(DDPStrategy):

         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
-            warnings.warn(
-                f'Shard init is not supported model.from_pretrained() yet. '
-                'Please load weights after strategy.prepare()'
-            )
+            warnings.warn(f'Shard init is not supported model.from_pretrained() yet. '
+                          'Please load weights after strategy.prepare()')
         if stage == 3 and precision == 'fp32':
             warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
             precision = 'fp16'
         self.precision = precision
         self.shard_init = shard_init

-        optim_kwargs = dict(
-            initial_scale=initial_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            min_scale=min_scale,
-            max_scale=max_scale,
-            max_norm=max_norm,
-            norm_type=norm_type
-        )
+        optim_kwargs = dict(initial_scale=initial_scale,
+                            growth_factor=growth_factor,
+                            backoff_factor=backoff_factor,
+                            growth_interval=growth_interval,
+                            hysteresis=hysteresis,
+                            min_scale=min_scale,
+                            max_scale=max_scale,
+                            max_norm=max_norm,
+                            norm_type=norm_type)
         # NOTE: dist should be initialized before calling get_current_device()
         if stage == 3:
             plugin_initializer = lambda: GeminiPlugin(
-                # gemini_config
+                # gemini_config
                 device=get_current_device(),
                 placement_policy=placement_policy,
                 precision=precision,
                 pin_memory=pin_memory,
                 force_outputs_fp32=force_outputs_fp32,
                 strict_ddp_mode=shard_init,
-                search_range_mb=search_range_mb,
+                search_range_m=search_range_m,
                 hidden_dim=hidden_dim,
-                min_chunk_size_mb=min_chunk_size_mb,
-                # zero_optim_config
+                min_chunk_size_m=min_chunk_size_m,
+                # zero_optim_config
                 gpu_margin_mem_ratio=gpu_margin_mem_ratio,
-                # optim_config
-                **optim_kwargs
-            )
+                # optim_config
+                **optim_kwargs)
         else:
             plugin_initializer = lambda: LowLevelZeroPlugin(
-                # zero_config
+                # zero_config
                 stage=stage,
                 precision=precision,
-                # zero_optim_config
+                # zero_optim_config
                 reduce_bucket_size_in_m=reduce_bucket_size,
                 overlap_communication=overlap_communication,
                 cpu_offload=(placement_policy == 'cpu'),
-                # optim_config
-                **optim_kwargs
-            )
+                # optim_config
+                **optim_kwargs)

         super().__init__(seed, plugin_initializer)
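
For quick reference, the two keyword renames in this hotfix, with their defaults left unchanged by the commit. The mapping below is an editorial summary and is not part of the diff itself:

# Hypothetical summary dict; it only restates what the hunks above show.
RENAMED_KWARGS = {
    'search_range_mb': 'search_range_m',      # ZeRO-3 chunk-size search range, value in units of 2**20 (default 32)
    'min_chunk_size_mb': 'min_chunk_size_m',  # ZeRO-3 minimum chunk size, value in units of 2**20 (default 32)
}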
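For anyone updating call sites, a minimal usage sketch with the new argument names follows. The import path and the `stage` keyword are assumptions inferred from the surrounding code (`if stage == 3`), not something this diff shows explicitly; the other defaults come from the signature hunk above.

# Sketch only: assumes ColossalAIStrategy is importable from
# coati.trainer.strategies and accepts a `stage` keyword, which this diff
# implies but does not show. Intended to run inside a distributed launch
# (the NOTE above says dist must be initialized first).
from coati.trainer.strategies import ColossalAIStrategy

# ZeRO-3 path: internally builds a GeminiPlugin with the renamed arguments.
gemini_strategy = ColossalAIStrategy(
    stage=3,
    placement_policy='cuda',
    search_range_m=32,      # formerly search_range_mb
    min_chunk_size_m=32,    # formerly min_chunk_size_mb
    gpu_margin_mem_ratio=0.0,
)

# ZeRO-1/2 path: internally builds a LowLevelZeroPlugin; placement_policy='cpu'
# turns on cpu_offload, as the else-branch above shows.
zero2_strategy = ColossalAIStrategy(
    stage=2,
    placement_policy='cpu',
    reduce_bucket_size=12 * 1024**2,
    overlap_communication=True,
)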