[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,13 +1,13 @@
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
@@ -31,11 +31,10 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from .dp_plugin_base import DPPluginBase
__all__ = ['TorchFSDPPlugin']
__all__ = ["TorchFSDPPlugin"]
class TorchFSDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
@@ -69,26 +68,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
def save_sharded_model(
self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool,
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
):
"""
Save model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def load_sharded_model(self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
def load_sharded_model(
self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
Load model to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int):
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save optimizer to checkpoint but only on master process.
"""
@@ -109,7 +118,6 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
class TorchFSDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)
@@ -119,7 +127,6 @@ class TorchFSDPModel(ModelWrapper):
class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model
super().__init__(optimizer)
@@ -147,7 +154,7 @@ class TorchFSDPPlugin(DPPluginBase):
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
if version.parse(torch.__version__) >= version.parse('1.12.0'):
if version.parse(torch.__version__) >= version.parse("1.12.0"):
def __init__(
self,
@@ -162,15 +169,18 @@ class TorchFSDPPlugin(DPPluginBase):
sync_module_states: bool = False,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states)
self.fsdp_kwargs = dict(
process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -184,13 +194,13 @@ class TorchFSDPPlugin(DPPluginBase):
return True
def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16']
return ["fp16", "bf16"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
return ["cuda"]
def configure(
self,
@@ -200,14 +210,13 @@ class TorchFSDPPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(
'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.'
"TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)