mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -37,10 +37,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
|
||||
|
||||
|
||||
class HybridParallelModule(ModelWrapper):
|
||||
|
||||
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
|
||||
ddp_config: dict, custom_policy: Policy) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: Module,
|
||||
precision: str,
|
||||
shard_config: ShardConfig,
|
||||
dp_group: ProcessGroup,
|
||||
use_ddp: bool,
|
||||
ddp_config: dict,
|
||||
custom_policy: Policy,
|
||||
) -> None:
|
||||
self.stage_manager = shard_config.pipeline_stage_manager
|
||||
self.dp_group = dp_group
|
||||
|
||||
@@ -54,13 +60,14 @@ class HybridParallelModule(ModelWrapper):
|
||||
for shared_param in self.shared_params:
|
||||
if len(shared_param) > 0:
|
||||
self.shared_param_process_groups.append(
|
||||
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
|
||||
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
|
||||
)
|
||||
|
||||
# setting mixed_precision
|
||||
self.mixed_precision = None
|
||||
if precision == 'fp16':
|
||||
if precision == "fp16":
|
||||
self.mixed_precision = torch.float16
|
||||
elif precision == 'bf16':
|
||||
elif precision == "bf16":
|
||||
self.mixed_precision = torch.bfloat16
|
||||
if self.mixed_precision is not None:
|
||||
module = module.to(self.mixed_precision)
|
||||
@@ -123,22 +130,21 @@ def get_param_info(optim: Optimizer):
|
||||
|
||||
if optim is None:
|
||||
return {}
|
||||
param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
|
||||
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
|
||||
start_index = 0
|
||||
for group in optim.param_groups:
|
||||
packed_group = {k: v for k, v in group.items() if k != "params"}
|
||||
packed_group["params"] = []
|
||||
|
||||
packed_group = {k: v for k, v in group.items() if k != 'params'}
|
||||
packed_group['params'] = []
|
||||
|
||||
for param_id, param in enumerate(group['params'], start_index):
|
||||
for param_id, param in enumerate(group["params"], start_index):
|
||||
original_shape = param.shape if isinstance(param, torch.Tensor) else None
|
||||
packed_group['params'].append(param_id)
|
||||
param_info['param2id'][id(param)] = param_id
|
||||
param_info['id2param'][param_id] = id(param)
|
||||
param_info['param2shape'][id(param)] = original_shape
|
||||
packed_group["params"].append(param_id)
|
||||
param_info["param2id"][id(param)] = param_id
|
||||
param_info["id2param"][param_id] = id(param)
|
||||
param_info["param2shape"][id(param)] = original_shape
|
||||
|
||||
param_info['param_groups'].append(packed_group)
|
||||
start_index += len(group['params'])
|
||||
param_info["param_groups"].append(packed_group)
|
||||
start_index += len(group["params"])
|
||||
|
||||
return param_info
|
||||
|
||||
@@ -147,13 +153,12 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
||||
model_params = set(model.parameters())
|
||||
new_param_groups = []
|
||||
for group in optim.param_groups:
|
||||
params = [p for p in group['params'] if p in model_params]
|
||||
new_param_groups.append({**group, 'params': params})
|
||||
optim.__setstate__({'param_groups': new_param_groups})
|
||||
params = [p for p in group["params"] if p in model_params]
|
||||
new_param_groups.append({**group, "params": params})
|
||||
optim.__setstate__({"param_groups": new_param_groups})
|
||||
|
||||
|
||||
class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||
|
||||
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
|
||||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
@@ -162,60 +167,87 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||
|
||||
|
||||
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||
|
||||
def __init__(self,
|
||||
optim: Optimizer,
|
||||
model: Module,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
precision: str = 'fp16',
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0):
|
||||
def __init__(
|
||||
self,
|
||||
optim: Optimizer,
|
||||
model: Module,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
precision: str = "fp16",
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0,
|
||||
):
|
||||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
|
||||
hysteresis, max_scale, max_norm)
|
||||
super().__init__(
|
||||
optim,
|
||||
precision,
|
||||
initial_scale,
|
||||
min_scale,
|
||||
growth_factor,
|
||||
backoff_factor,
|
||||
growth_interval,
|
||||
hysteresis,
|
||||
max_scale,
|
||||
max_norm,
|
||||
)
|
||||
|
||||
|
||||
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
model: Module,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
growth_factor: float = 2.,
|
||||
backoff_factor: float = .5,
|
||||
growth_interval: int = 2000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: int = 2**24,
|
||||
clip_grad_norm: float = 0.0, # grad clipping
|
||||
verbose: bool = False,
|
||||
reduce_bucket_size: int = 1024 * 1024, # communication
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
|
||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
forced_dtype: Optional[torch.dtype] = None):
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
model: Module,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
growth_factor: float = 2.0,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 2000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: int = 2**24,
|
||||
clip_grad_norm: float = 0.0, # grad clipping
|
||||
verbose: bool = False,
|
||||
reduce_bucket_size: int = 1024 * 1024, # communication
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
|
||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optimizer, model)
|
||||
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
|
||||
hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype,
|
||||
overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group,
|
||||
forced_dtype)
|
||||
super().__init__(
|
||||
optimizer,
|
||||
initial_scale,
|
||||
min_scale,
|
||||
growth_factor,
|
||||
backoff_factor,
|
||||
growth_interval,
|
||||
hysteresis,
|
||||
max_scale,
|
||||
clip_grad_norm,
|
||||
verbose,
|
||||
reduce_bucket_size,
|
||||
communication_dtype,
|
||||
overlap_communication,
|
||||
partition_grad,
|
||||
cpu_offload,
|
||||
dp_process_group,
|
||||
tp_process_group,
|
||||
forced_dtype,
|
||||
)
|
||||
|
||||
|
||||
class HybridParallelPlugin(PipelinePluginBase):
|
||||
@@ -276,46 +308,47 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
precision: str = 'fp16',
|
||||
zero_stage: int = 0,
|
||||
enable_all_optimization: bool = False,
|
||||
enable_fused_normalization: bool = False,
|
||||
enable_flash_attention: bool = False,
|
||||
enable_jit_fused: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
enable_sequence_overlap: bool = False,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0,
|
||||
broadcast_buffers: bool = True,
|
||||
ddp_bucket_cap_mb: int = 25,
|
||||
find_unused_parameters: bool = False,
|
||||
check_reduction: bool = False,
|
||||
gradient_as_bucket_view: bool = False,
|
||||
static_graph: bool = False,
|
||||
zero_bucket_size_in_m: int = 12,
|
||||
cpu_offload: bool = False,
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
custom_policy: Policy = None) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
precision: str = "fp16",
|
||||
zero_stage: int = 0,
|
||||
enable_all_optimization: bool = False,
|
||||
enable_fused_normalization: bool = False,
|
||||
enable_flash_attention: bool = False,
|
||||
enable_jit_fused: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
enable_sequence_overlap: bool = False,
|
||||
num_microbatches: Optional[int] = None,
|
||||
microbatch_size: Optional[int] = None,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0,
|
||||
broadcast_buffers: bool = True,
|
||||
ddp_bucket_cap_mb: int = 25,
|
||||
find_unused_parameters: bool = False,
|
||||
check_reduction: bool = False,
|
||||
gradient_as_bucket_view: bool = False,
|
||||
static_graph: bool = False,
|
||||
zero_bucket_size_in_m: int = 12,
|
||||
cpu_offload: bool = False,
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
custom_policy: Policy = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dist.get_world_size() % (
|
||||
tp_size * pp_size
|
||||
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
|
||||
assert (
|
||||
dist.get_world_size() % (tp_size * pp_size) == 0
|
||||
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
||||
|
||||
if enable_sequence_parallelism:
|
||||
assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
|
||||
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
|
||||
|
||||
self.tp_size = tp_size
|
||||
self.pp_size = pp_size
|
||||
@@ -334,24 +367,28 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
self.custom_policy = custom_policy
|
||||
assert zero_stage in (0, 1, 2)
|
||||
if self.pp_size > 1:
|
||||
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
|
||||
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
|
||||
assert (
|
||||
num_microbatches is not None or microbatch_size is not None
|
||||
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
||||
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
|
||||
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
|
||||
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
|
||||
num_microbatches=num_microbatches,
|
||||
microbatch_size=microbatch_size)
|
||||
self.schedule = OneForwardOneBackwardSchedule(
|
||||
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
|
||||
)
|
||||
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
|
||||
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
|
||||
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
|
||||
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
|
||||
pipeline_stage_manager=self.stage_manager,
|
||||
enable_tensor_parallelism=self.tp_size > 1,
|
||||
enable_all_optimization=self.enable_all_optimization,
|
||||
enable_fused_normalization=self.enable_fused_normalization,
|
||||
enable_flash_attention=self.enable_flash_attention,
|
||||
enable_jit_fused=self.enable_jit_fused,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
enable_sequence_overlap=enable_sequence_overlap)
|
||||
self.shard_config = ShardConfig(
|
||||
tensor_parallel_process_group=self.tp_group,
|
||||
pipeline_stage_manager=self.stage_manager,
|
||||
enable_tensor_parallelism=self.tp_size > 1,
|
||||
enable_all_optimization=self.enable_all_optimization,
|
||||
enable_fused_normalization=self.enable_fused_normalization,
|
||||
enable_flash_attention=self.enable_flash_attention,
|
||||
enable_jit_fused=self.enable_jit_fused,
|
||||
enable_sequence_parallelism=enable_sequence_parallelism,
|
||||
enable_sequence_overlap=enable_sequence_overlap,
|
||||
)
|
||||
self.amp_config = dict(
|
||||
initial_scale=initial_scale,
|
||||
growth_factor=growth_factor,
|
||||
@@ -362,18 +399,22 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
max_scale=max_scale,
|
||||
)
|
||||
|
||||
self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
|
||||
bucket_cap_mb=ddp_bucket_cap_mb,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
check_reduction=check_reduction,
|
||||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
static_graph=static_graph)
|
||||
self.ddp_config = dict(
|
||||
broadcast_buffers=broadcast_buffers,
|
||||
bucket_cap_mb=ddp_bucket_cap_mb,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
check_reduction=check_reduction,
|
||||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
static_graph=static_graph,
|
||||
)
|
||||
|
||||
self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
|
||||
communication_dtype=communication_dtype,
|
||||
overlap_communication=overlap_communication,
|
||||
cpu_offload=cpu_offload,
|
||||
partition_grad=(self.zero_stage == 2))
|
||||
self.zero_config = dict(
|
||||
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
|
||||
communication_dtype=communication_dtype,
|
||||
overlap_communication=overlap_communication,
|
||||
cpu_offload=cpu_offload,
|
||||
partition_grad=(self.zero_stage == 2),
|
||||
)
|
||||
|
||||
self.max_norm = max_norm
|
||||
|
||||
@@ -382,10 +423,10 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
return self.pp_size > 1
|
||||
|
||||
def supported_devices(self) -> List[str]:
|
||||
return ['cuda']
|
||||
return ["cuda"]
|
||||
|
||||
def supported_precisions(self) -> List[str]:
|
||||
return ['fp16', 'bf16', 'fp32']
|
||||
return ["fp16", "bf16", "fp32"]
|
||||
|
||||
def control_device(self) -> bool:
|
||||
return True
|
||||
@@ -410,57 +451,67 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
param_info = get_param_info(optimizer)
|
||||
if not isinstance(model, ModelWrapper):
|
||||
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
|
||||
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
|
||||
self.ddp_config, self.custom_policy)
|
||||
model = HybridParallelModule(
|
||||
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.zero_stage == 0:
|
||||
if self.precision in ['fp16', 'bf16']:
|
||||
optimizer = HybridParallelAMPOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
precision=self.precision,
|
||||
max_norm=self.max_norm,
|
||||
**self.amp_config)
|
||||
self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
|
||||
optimizer.master_to_working_map)
|
||||
if self.precision in ["fp16", "bf16"]:
|
||||
optimizer = HybridParallelAMPOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
precision=self.precision,
|
||||
max_norm=self.max_norm,
|
||||
**self.amp_config,
|
||||
)
|
||||
self.checkpoint_io.link_master_and_working_param(
|
||||
optimizer.working_to_master_map, optimizer.master_to_working_map
|
||||
)
|
||||
else:
|
||||
optimizer = HybridParallelNaiveOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info)
|
||||
optimizer = HybridParallelNaiveOptimizer(
|
||||
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
|
||||
)
|
||||
else:
|
||||
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
|
||||
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
|
||||
optimizer = HybridParallelZeroOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
dp_process_group=self.dp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
verbose=True,
|
||||
clip_grad_norm=self.max_norm,
|
||||
**self.zero_config,
|
||||
**self.amp_config)
|
||||
self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
|
||||
optimizer._param_store.master_to_working_param)
|
||||
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
|
||||
optimizer = HybridParallelZeroOptimizer(
|
||||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
param_info=param_info,
|
||||
dp_process_group=self.dp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
verbose=True,
|
||||
clip_grad_norm=self.max_norm,
|
||||
**self.zero_config,
|
||||
**self.amp_config,
|
||||
)
|
||||
self.checkpoint_io.link_master_and_working_param(
|
||||
optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
|
||||
)
|
||||
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
def execute_pipeline(self,
|
||||
data_iter: Iterator,
|
||||
model: HybridParallelModule,
|
||||
criterion: Callable[[Any, Any], torch.Tensor],
|
||||
optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
|
||||
HybridParallelZeroOptimizer]] = None,
|
||||
return_loss: bool = True,
|
||||
return_outputs: bool = False) -> dict:
|
||||
assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
|
||||
def execute_pipeline(
|
||||
self,
|
||||
data_iter: Iterator,
|
||||
model: HybridParallelModule,
|
||||
criterion: Callable[[Any, Any], torch.Tensor],
|
||||
optimizer: Optional[
|
||||
Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer]
|
||||
] = None,
|
||||
return_loss: bool = True,
|
||||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
|
||||
# return loss or outputs if needed
|
||||
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
with ctx:
|
||||
outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
|
||||
return_outputs)
|
||||
outputs = self.schedule.forward_backward_step(
|
||||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
model.sync_shared_params()
|
||||
if isinstance(optimizer, HybridParallelZeroOptimizer):
|
||||
optimizer.sync_grad()
|
||||
@@ -468,15 +519,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
model.sync_grads()
|
||||
return outputs
|
||||
|
||||
def prepare_dataloader(self,
|
||||
dataset,
|
||||
batch_size,
|
||||
shuffle=False,
|
||||
seed=1024,
|
||||
drop_last=False,
|
||||
pin_memory=False,
|
||||
num_workers=0,
|
||||
**kwargs):
|
||||
def prepare_dataloader(
|
||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
||||
):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
|
||||
@@ -499,10 +544,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
sampler = DistributedSampler(dataset,
|
||||
num_replicas=self.pg_mesh.size(DP_AXIS),
|
||||
rank=self.pg_mesh.coordinate(DP_AXIS),
|
||||
shuffle=shuffle)
|
||||
sampler = DistributedSampler(
|
||||
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
@@ -511,14 +555,16 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
|
||||
|
Reference in New Issue
Block a user