Mirror of https://github.com/hpcaitech/ColossalAI.git
fixed mkdir conflict and align yapf config with flake (#220)
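The "mkdir conflict" addressed below comes from the fact that checking for a directory and then creating it are two separate steps: when several processes (for example, one per rank) save checkpoints into the same directory, a process that loses the race can call os.makedirs on a path that was just created by another process and crash with FileExistsError. Passing exist_ok=True makes the creation idempotent. A minimal sketch of the two behaviours, using a hypothetical ./checkpoints path purely for illustration:

import os
from pathlib import Path

ckpt_dir = './checkpoints'   # hypothetical directory, only for illustration

# Pre-fix pattern: not idempotent, so a rank that loses the race between the
# existence check and the creation raises FileExistsError.
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

# Post-fix pattern: exist_ok=True turns "already exists" into a no-op,
# so concurrent ranks can all call it safely.
Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

os.makedirs(ckpt_dir, exist_ok=True) would behave the same way; the commit opts for the pathlib form.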
@@ -2,6 +2,7 @@ import os
 import os.path as osp
 import re
 from typing import Tuple
+from pathlib import Path
 
 import torch
 
@@ -10,10 +11,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 
 __all__ = [
-    'get_checkpoint_path',
-    'get_latest_checkpoint_path',
-    'get_latest_checkpoint_pattern',
-    'save_checkpoint',
+    'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
     'load_checkpoint'
 ]
 
@@ -70,9 +68,9 @@ def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
 
 def _ensure_directory_exists(filename: str):
     # ensure the directory exists
-    dir = os.path.dirname(filename)
-    if not os.path.exists(dir):
-        os.makedirs(dir)
+    dirpath = os.path.dirname(filename)
+    if not os.path.exists(dirpath):
+        Path(dirpath).mkdir(parents=True, exist_ok=True)
 
 
 def get_latest_checkpoint_pattern(suffix: str = ''):
@@ -84,7 +82,8 @@ def get_latest_checkpoint_pattern(suffix: str = ''):
     :rtype: regular expression
     """
     ranks_name = _get_ranks_name()
-    ckpt_pattern = re.compile(f'epoch(\d+)-{ranks_name}{suffix}\.pt')
+    pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
+    ckpt_pattern = re.compile(pattern)
     return ckpt_pattern
 
 
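Besides being easier to keep within the formatter's line-length limit, the rewritten pattern is a raw string, so the \d and \. escapes reach the regex engine as-is instead of relying on invalid string escapes, which recent Python versions flag with an invalid-escape-sequence warning. A small sketch of the compiled pattern in use, with a hypothetical rank tag since the real value comes from _get_ranks_name():

import re

ranks_name = 'rank_0'   # hypothetical; produced by _get_ranks_name() in the real code
suffix = ''

pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
ckpt_pattern = re.compile(pattern)

match = ckpt_pattern.match('epoch12-rank_0.pt')
print(match.group(1) if match else None)   # -> '12', the epoch number encoded in the file name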
@@ -127,7 +126,8 @@ def save_checkpoint(checkpoint_path: str,
                     optimizer: torch.optim.Optimizer,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                     **kwargs):
-    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler and etc. into a checkpoint dictionary.
+    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model,
+    optimizer, lr_scheduler and etc. into a checkpoint dictionary.
 
     This method can be used for both colosalai nn.BaseModel and normal pytorch nn.Module.
 
@@ -150,12 +150,7 @@ def save_checkpoint(checkpoint_path: str,
     model_sd = model.state_dict()
 
     # ckpt container
-    checkpoint = {
-        'epoch': epoch,
-        'model': model_sd,
-        'optimizer': optimizer.state_dict(),
-        **kwargs
-    }
+    checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
     if lr_scheduler is not None:
         checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
 
@@ -171,9 +166,11 @@ def load_checkpoint(checkpoint_path: str,
                     strict: bool = True) -> Tuple:
     """Loads the checkpoint file.
     If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
-    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler) and its descendants.
+    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
+    and its descendants.
     If finetune is True, then only the weights and buffers of model should be reload.
-    If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
+    If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s
+    state_dict() function.
 
     :param checkpoint_path: The exact and matched checkpoint_path directory to retrieve appropriate state_dict
     :type checkpoint_path: str
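For orientation, below is a minimal sketch of the round trip these helpers implement, assuming only the checkpoint layout visible in this diff ('epoch', 'model', 'optimizer', optional 'lr_scheduler') and plain PyTorch calls; it is not the library's actual save/load code, and the file name is hypothetical.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# What save_checkpoint assembles, per the diff: one dict holding each component's state.
checkpoint = {'epoch': 3, 'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'epoch3-rank_0.pt')          # hypothetical file name

# What load_checkpoint undoes: strict=True requires the model keys to match exactly,
# as described in the docstring above.
state = torch.load('epoch3-rank_0.pt', map_location='cpu')
model.load_state_dict(state['model'], strict=True)
optimizer.load_state_dict(state['optimizer'])
last_epoch = state['epoch']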