Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
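The changes in this file are mechanical: the updated hooks normalize strings to double quotes and rewrap long signatures and calls to one argument per line with a trailing comma. A minimal, repository-independent sketch of that rewrite (illustrative names only, not code from the diff):

# Illustrative sketch only (not code from the repository).

# Before: yapf aligns continuation lines under the opening parenthesis,
# and the codebase used single-quoted strings.
def save_checkpoint(path: str,
                    shard: bool = False,
                    size_per_shard: int = 1024):
    return 'saved to ' + path

# After: when a signature is too long for one line, the new formatter explodes
# it to one parameter per line with a trailing comma, and normalizes quotes.
def save_checkpoint(
    path: str,
    shard: bool = False,
    size_per_shard: int = 1024,
):
    return "saved to " + path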
@@ -1,14 +1,12 @@
 import logging
 import os
 import warnings
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -33,7 +31,7 @@ from colossalai.zero import LowLevelZeroOptimizer
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO

-__all__ = ['LowLevelZeroPlugin']
+__all__ = ["LowLevelZeroPlugin"]


 def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
@@ -42,17 +40,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x


-SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):

     def __init__(self, module: nn.Module, precision: str) -> None:
         super().__init__(module)
         self.dtype = None
-        if precision == 'fp16':
+        if precision == "fp16":
             self.dtype = torch.float16
-        elif precision == 'bf16':
+        elif precision == "bf16":
             self.dtype = torch.bfloat16
         if self.dtype is not None:
             module = module.to(self.dtype)
@@ -74,7 +71,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):


 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
         """Save optimizer to checkpoint but only on master process.

@@ -91,12 +87,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors=False)

-    def save_sharded_optimizer(self,
-                               optimizer: OptimizerWrapper,
-                               checkpoint: str,
-                               gather_dtensor: bool = False,
-                               prefix: str = None,
-                               size_per_shard: int = 1024):
+    def save_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        gather_dtensor: bool = False,
+        prefix: str = None,
+        size_per_shard: int = 1024,
+    ):
         """
         Save sharded Zero-optimizer checkpoint under the given checkpointing path.
         The following files will be created under the path:
@@ -148,9 +146,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         if self.coordinator.is_master():
             index_file.write_index_file(save_index_file)
-        logging.info(f"The optimizer is going to be split to checkpoint shards. "
-                     f"You can find where each parameters has been saved in the "
-                     f"index located at {save_index_file}.")
+        logging.info(
+            f"The optimizer is going to be split to checkpoint shards. "
+            f"You can find where each parameters has been saved in the "
+            f"index located at {save_index_file}."
+        )

     def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
         """Load sharded optimizer with the given path to index file.
@@ -170,8 +170,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # Load param_groups
         param_group_path = ckpt_index_file.get_param_group_filename()
         if param_group_path is None:
-            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
-                               Lacking param group file under current directory.')
+            raise RuntimeError(
+                f"Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory."
+            )
         id_map = load_param_groups_into_optimizer(optimizer, param_group_path)

         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
@@ -181,9 +183,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():
-                    if isinstance(v, torch.Tensor) and k != 'step':
-                        padding_size = (self.coordinator.world_size -
-                                        v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        padding_size = (
+                            self.coordinator.world_size - v.numel() % self.coordinator.world_size
+                        ) % self.coordinator.world_size
                         with torch.no_grad():
                             v = v.flatten()
                             if padding_size > 0:
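The padding arithmetic rewrapped in the hunk above pads a flattened state tensor so its length is divisible by the world size before it is split across ranks. A minimal standalone sketch of that logic (function and variable names are illustrative, not the plugin's internals):

import torch

def shard_for_rank(v: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Pad the flattened tensor so its numel is a multiple of world_size,
    # then take this rank's contiguous chunk.
    padding_size = (world_size - v.numel() % world_size) % world_size
    v = v.flatten()
    if padding_size > 0:
        v = torch.nn.functional.pad(v, [0, padding_size])
    return v.split(v.numel() // world_size)[rank]

# Example: 10 elements over 4 ranks -> padded to 12, each rank gets 3.
chunk = shard_for_rank(torch.arange(10, dtype=torch.float32), world_size=4, rank=1)
print(chunk)  # tensor([3., 4., 5.])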
@@ -194,33 +197,39 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

         sharded_optimizer_loading_epilogue(optimizer)

-    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
-                             use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
+    ):
         assert isinstance(model, LowLevelZeroModel)
         super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)

-    def save_sharded_model(self,
-                           model: nn.Module,
-                           checkpoint_path: str,
-                           gather_dtensor: bool = True,
-                           prefix: Optional[str] = None,
-                           max_shard_size: int = 1024,
-                           use_safetensors: bool = False):
+    def save_sharded_model(
+        self,
+        model: nn.Module,
+        checkpoint_path: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        max_shard_size: int = 1024,
+        use_safetensors: bool = False,
+    ):
         assert isinstance(model, LowLevelZeroModel)
-        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
-                                   use_safetensors)
+        super().save_sharded_model(
+            model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+        )

     def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
         assert isinstance(model, LowLevelZeroModel)
         super().load_unsharded_model(model.module, checkpoint, strict)
         model.update_master_params()

-    def load_sharded_model(self,
-                           model: LowLevelZeroModel,
-                           checkpoint_index_file: Path,
-                           strict: bool = False,
-                           use_safetensors: bool = False,
-                           load_sub_module: bool = True):
+    def load_sharded_model(
+        self,
+        model: LowLevelZeroModel,
+        checkpoint_index_file: Path,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        load_sub_module: bool = True,
+    ):
         assert isinstance(model, LowLevelZeroModel)
         super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()
@@ -264,7 +273,7 @@ class LowLevelZeroPlugin(DPPluginBase):
     def __init__(
         self,
         stage: int = 1,
-        precision: str = 'fp16',
+        precision: str = "fp16",
         initial_scale: float = 2**32,
         min_scale: float = 1,
         growth_factor: float = 2,
@@ -281,9 +290,9 @@ class LowLevelZeroPlugin(DPPluginBase):
         verbose: bool = False,
     ) -> None:
         super().__init__()
-        assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
-        assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
-        assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
+        assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
+        assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
+        assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
         self.stage = stage
         self.precision = precision
         self.zero_optim_kwargs = dict(
@@ -319,7 +328,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         return True

     def supported_devices(self) -> List[str]:
-        return ['cuda']
+        return ["cuda"]

     def configure(
         self,
@@ -329,15 +338,13 @@ class LowLevelZeroPlugin(DPPluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)

-        if optimizer is not None and \
-                not isinstance(optimizer, OptimizerWrapper):
-            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
-                                                                     **self.zero_optim_kwargs,
-                                                                     verbose=self.verbose)
+        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
+                optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+            )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)

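For context, the plugin reformatted above is normally driven through the Booster API. A hedged usage sketch (model, data, and training loop are placeholders; the launch and boost() calls are assumed from the standard ColossalAI workflow, not from this diff):

# Hedged usage sketch: run under torchrun so the process group can be initialized.
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

plugin = LowLevelZeroPlugin(stage=2, precision="bf16")  # stage must be 1 or 2
booster = Booster(plugin=plugin)

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# configure() above wraps these into LowLevelZeroModel / LowLevelZeroOptimizer.
model, optimizer, *_ = booster.boost(model, optimizer)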