Repository: https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,13 +1,13 @@
 import warnings
 from pathlib import Path
-from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterable, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from packaging import version
 from torch.distributed import ProcessGroup
 
-if version.parse(torch.__version__) >= version.parse('1.12.0'):
+if version.parse(torch.__version__) >= version.parse("1.12.0"):
     from torch.distributed.fsdp import FullStateDictConfig
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp import StateDictType
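
The three guarded imports above are the pieces FSDP full-state-dict checkpointing needs; they only exist in torch >= 1.12.0, hence the version gate. A minimal sketch of how they combine, assuming an initialized process group and an already-wrapped fsdp_model (a placeholder name):

    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Gather a full (unsharded) state dict, materialized on rank 0 only.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        full_state = fsdp_model.state_dict()
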
@@ -31,11 +31,10 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .dp_plugin_base import DPPluginBase
 
-__all__ = ['TorchFSDPPlugin']
+__all__ = ["TorchFSDPPlugin"]
 
 
 class TorchFSDPCheckpointIO(GeneralCheckpointIO):
-
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
@@ -69,26 +68,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
         utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
 
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
-                           size_per_shard: int, use_safetensors: bool):
+    def save_sharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        gather_dtensor: bool,
+        prefix: Optional[str],
+        size_per_shard: int,
+        use_safetensors: bool,
+    ):
         """
         Save model to checkpoint but only on master process.
         """
         raise NotImplementedError("Sharded model checkpoint is not supported yet.")
 
-    def load_sharded_model(self,
-                           model: nn.Module,
-                           checkpoint_index_file: Path,
-                           strict: bool = False,
-                           use_safetensors: bool = False,
-                           load_sub_module: bool = True):
+    def load_sharded_model(
+        self,
+        model: nn.Module,
+        checkpoint_index_file: Path,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        load_sub_module: bool = True,
+    ):
         """
         Load model to checkpoint but only on master process.
         """
         raise NotImplementedError("Sharded model checkpoint is not supported yet.")
 
-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
-                               size_per_shard: int):
+    def save_sharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
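
The unchanged context lines at the top of this hunk show the unsharded optimizer checkpoint path. A hedged standalone sketch of the same call, where fsdp_model, optimizer, and the output path are placeholders:

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # rank0_only=True gathers the full optimizer state on rank 0;
    # other ranks receive an empty dict, so only rank 0 writes.
    full_osd = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
    if dist.get_rank() == 0:
        torch.save(full_osd, "optimizer.pt")
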
@@ -109,7 +118,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
 
 
 class TorchFSDPModel(ModelWrapper):
-
     def __init__(self, module: nn.Module, *args, **kwargs) -> None:
         super().__init__(module)
         self.module = FSDP(module, *args, **kwargs)
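
TorchFSDPModel is a thin ModelWrapper that delegates everything to PyTorch's FSDP. A minimal sketch of the equivalent wrapping, assuming torch.distributed is already initialized (e.g. via torchrun):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    module = nn.Linear(1024, 1024)
    # FSDP flattens and shards the parameters across ranks in-place.
    fsdp_module = FSDP(module, device_id=torch.cuda.current_device())
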
@@ -119,7 +127,6 @@ class TorchFSDPModel(ModelWrapper):
 
 
 class FSDPOptimizerWrapper(OptimizerWrapper):
-
    def __init__(self, optimizer: Optimizer, model: nn.Module):
        self.model = model
        super().__init__(optimizer)
@@ -147,7 +154,7 @@ class TorchFSDPPlugin(DPPluginBase):
     See https://pytorch.org/docs/stable/fsdp.html for details.
     """
 
-    if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    if version.parse(torch.__version__) >= version.parse("1.12.0"):
 
         def __init__(
             self,
@@ -162,15 +169,18 @@ class TorchFSDPPlugin(DPPluginBase):
             sync_module_states: bool = False,
         ):
             super().__init__()
-            self.fsdp_kwargs = dict(process_group=process_group,
-                                    sharding_strategy=sharding_strategy,
-                                    cpu_offload=cpu_offload,
-                                    auto_wrap_policy=auto_wrap_policy,
-                                    backward_prefetch=backward_prefetch,
-                                    mixed_precision=mixed_precision,
-                                    ignored_modules=ignored_modules,
-                                    param_init_fn=param_init_fn,
-                                    sync_module_states=sync_module_states)
+            self.fsdp_kwargs = dict(
+                process_group=process_group,
+                sharding_strategy=sharding_strategy,
+                cpu_offload=cpu_offload,
+                auto_wrap_policy=auto_wrap_policy,
+                backward_prefetch=backward_prefetch,
+                mixed_precision=mixed_precision,
+                ignored_modules=ignored_modules,
+                param_init_fn=param_init_fn,
+                sync_module_states=sync_module_states,
+            )
 
         else:
             raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
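
The collected fsdp_kwargs are forwarded verbatim to FullyShardedDataParallel when the plugin later wraps the model. A hedged construction sketch; the CPUOffload choice is illustrative, not required:

    from torch.distributed.fsdp import CPUOffload
    from colossalai.booster.plugin import TorchFSDPPlugin

    # Every keyword mirrors the corresponding FSDP constructor argument.
    plugin = TorchFSDPPlugin(cpu_offload=CPUOffload(offload_params=True))
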
@@ -184,13 +194,13 @@ class TorchFSDPPlugin(DPPluginBase):
         return True
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16', 'bf16']
+        return ["fp16", "bf16"]
 
     def control_device(self) -> bool:
         return True
 
    def supported_devices(self) -> List[str]:
-        return ['cuda']
+        return ["cuda"]
 
     def configure(
         self,
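
The fp16/bf16 precisions advertised above map onto FSDP's native MixedPrecision config. A hedged sketch of wiring it through the plugin; the bf16 choice is illustrative:

    import torch
    from torch.distributed.fsdp import MixedPrecision
    from colossalai.booster.plugin import TorchFSDPPlugin

    plugin = TorchFSDPPlugin(
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,   # compute/communication dtype for params
            reduce_dtype=torch.bfloat16,  # dtype used for gradient reduction
            buffer_dtype=torch.bfloat16,  # dtype for module buffers
        )
    )
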
@@ -200,14 +210,13 @@ class TorchFSDPPlugin(DPPluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
-
         # wrap the model with PyTorch FSDP
         fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
 
         if optimizer is not None:
             if len(optimizer.param_groups) > 1:
                 warnings.warn(
-                    'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
+                    "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
                 )
             optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)
 
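
Note the final line of this hunk: the optimizer is re-initialized against fsdp_model.parameters() because FSDP replaces the original parameters with flattened shards, invalidating the references the optimizer was built with. A hedged end-to-end sketch of reaching configure() through the Booster API of this era; the model and optimizer are placeholders:

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchFSDPPlugin

    colossalai.launch_from_torch(config={})  # expects torchrun-style env vars

    model = torch.nn.Linear(1024, 1024)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # boost() invokes the plugin's configure(), which wraps the model in FSDP
    # and rebinds the optimizer to the new flat parameters as shown above.
    booster = Booster(plugin=TorchFSDPPlugin())
    model, optimizer, *_ = booster.boost(model, optimizer)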