mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-22 07:14:09 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			212 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			212 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | ||
| import os.path as osp
 | ||
| import re
 | ||
| from typing import Tuple
 | ||
| from pathlib import Path
 | ||
| 
 | ||
| import torch
 | ||
| 
 | ||
| from colossalai.context import Config
 | ||
| from colossalai.context.parallel_mode import ParallelMode
 | ||
| from colossalai.core import global_context as gpc
 | ||
| 
 | ||
# Names exported by `from <module> import *`; other helpers (e.g. unwrap_config) remain importable explicitly.
__all__ = [
    'get_checkpoint_path',
    'get_latest_checkpoint_path',
    'get_latest_checkpoint_pattern',
    'save_checkpoint',
    'load_checkpoint',
]
 | ||
| 
 | ||
| 
 | ||
def unwrap_config(config: Config):
    """Recursively convert a Config object into plain, nested ``dict`` objects.

    Every value that is a ``dict`` instance (including dict subclasses) is unwrapped recursively;
    any other value is copied into the result unchanged. Key order is preserved.

    :param config: Configuration to unwrap (only its ``items()`` method is used)
    :type config: Config
    :return: A new plain dict mirroring ``config``
    :rtype: dict
    """
    return {
        key: unwrap_config(value) if isinstance(value, dict) else value
        for key, value in config.items()
    }
 | ||
| 
 | ||
| 
 | ||
def _get_ranks_name():
    """Return the ``tp<rank>-pp<rank>`` tag for this process's tensor and pipeline parallel local ranks.

    A parallel mode that is not initialized in the global context contributes rank 0.
    """

    def _local_rank_or_zero(mode):
        # Rank 0 is used when this parallel mode has not been initialized.
        return gpc.get_local_rank(mode) if gpc.is_initialized(mode) else 0

    return f'tp{_local_rank_or_zero(ParallelMode.TENSOR)}-pp{_local_rank_or_zero(ParallelMode.PIPELINE)}'
 | ||
| 
 | ||
| 
 | ||
def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
    """Return the canonical checkpoint file name ``epoch<epoch>-tp<r>-pp<r><suffix>.pt`` for this process."""
    return 'epoch{}-{}{}.pt'.format(epoch, _get_ranks_name(), suffix)
 | ||
| 
 | ||
| 
 | ||
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
    """Build this process's checkpoint file path from (checkpoint_dir, epoch, suffix, parallel ranks).

    Used both when writing a checkpoint and when locating one to restore.

    :param checkpoint_dir: Directory in which checkpoints are stored
    :type checkpoint_dir: str
    :param epoch: Epoch number (how many epochs the model has been trained for)
    :type epoch: int
    :param suffix: Additional notation to distinguish the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :return: ``checkpoint_dir`` joined with ``epoch<epoch>-tp<r>-pp<r><suffix>.pt``
    :rtype: str
    """
    return os.path.join(checkpoint_dir, _get_standard_checkpoint_filename(epoch, suffix))
 | ||
| 
 | ||
| 
 | ||
def _ensure_directory_exists(filename: str):
    """Create the parent directory of ``filename`` (and any missing ancestors) if it does not exist yet."""
    parent_dir = os.path.dirname(filename)
    if os.path.exists(parent_dir):
        return
    # exist_ok=True tolerates the directory appearing between the check above and this call.
    Path(parent_dir).mkdir(parents=True, exist_ok=True)
 | ||
| 
 | ||
| 
 | ||
def get_latest_checkpoint_pattern(suffix: str = ''):
    """Build the compiled regular expression matching this process's checkpoint file names.

    The pattern matches ``epoch<N>-tp<r>-pp<r><suffix>.pt`` and captures the epoch digits ``<N>`` as
    group 1. The rank tag and ``suffix`` are escaped, so a suffix containing regex metacharacters
    (e.g. ``'.best'``) only matches itself. The pattern is anchored at the end, so names with trailing
    extras such as ``epoch3-tp0-pp0.pt.tmp`` are not matched.

    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :return: Compiled checkpoint pattern
    :rtype: re.Pattern
    """
    ranks_name = _get_ranks_name()
    # Everything after the epoch digits is a literal file-name fragment, not regex syntax.
    literal_tail = re.escape('-{}{}.pt'.format(ranks_name, suffix))
    ckpt_pattern = re.compile(r'epoch(\d+)' + literal_tail + '$')
    return ckpt_pattern
 | ||
| 
 | ||
| 
 | ||
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
    """Return the path of this process's checkpoint with the highest epoch in ``checkpoint_dir``.

    Useful when resuming training without knowing the epoch number. Only files whose name is exactly
    the canonical ``epoch<N>-tp<r>-pp<r><suffix>.pt`` for this process's parallel ranks are considered.

    :param checkpoint_dir: Directory in which checkpoints are stored
    :type checkpoint_dir: str
    :param suffix: Additional notation to specify the model or checkpoint, defaults to ''
    :type suffix: str, optional
    :raises AssertionError: If ``checkpoint_dir`` is not a directory (the check is skipped under ``python -O``)
    :raises FileNotFoundError: If no matching checkpoint file is found in ``checkpoint_dir``
    :return: The latest checkpoint path
    :rtype: str
    """
    ckpt_name_pattern = get_latest_checkpoint_pattern(suffix=suffix)

    assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'

    last_epoch = -1
    for filename in os.listdir(checkpoint_dir):
        match = ckpt_name_pattern.match(filename)
        if match is None:
            continue
        epoch = int(match.group(1))
        # A regex match alone does not guarantee this is the file whose path we return. Examples are
        # zero-padded epochs and, with a pattern that is not end-anchored, names with extra extensions
        # like '.pt.tmp'. Only accept the epoch when the name is the canonical one, so the returned
        # path refers to a file that was actually listed.
        if filename != _get_standard_checkpoint_filename(epoch, suffix=suffix):
            continue
        last_epoch = max(last_epoch, epoch)

    if last_epoch == -1:
        ranks_name = _get_ranks_name()
        raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")

    target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
    return osp.join(checkpoint_dir, target_file)
 | ||
| 
 | ||
| 
 | ||
def save_checkpoint(checkpoint_path: str,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    **kwargs):
    """Save the training state (model, optimizer, optional lr_scheduler and extra entries) to a file.

    The checkpoint is a dict with the keys ``'epoch'``, ``'model'`` and ``'optimizer'``. It also has
    ``'lr_scheduler'`` when a scheduler is given, plus one top-level entry per extra keyword argument.
    It works both with models that provide ``state_dict_for_save_checkpoint()`` (e.g. ColossalAI
    models) and with plain PyTorch ``nn.Module`` objects.

    The parent directory of ``checkpoint_path`` is created if needed. The data is first written to a
    hidden temporary file in the same directory and then moved into place with ``os.replace``. As a
    result, an interrupted save never leaves a truncated file under ``checkpoint_path``.

    :param checkpoint_path: Path of the checkpoint file to write (e.g. from ``get_checkpoint_path``)
    :type checkpoint_path: str
    :param epoch: Epoch number (how many epochs the model has been trained for)
    :type epoch: int
    :param model: Model whose state is saved
    :type model: torch.nn.Module
    :param optimizer: Optimizer whose state is saved
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: lr_scheduler whose state is saved, defaults to None
    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
    :param kwargs: Additional objects stored as top-level entries of the checkpoint dict
        (they are serialized by ``torch.save``)
    """
    # for compatibility with normal pytorch nn.Module
    if hasattr(model, 'state_dict_for_save_checkpoint'):
        model_sd = model.state_dict_for_save_checkpoint()
    else:
        model_sd = model.state_dict()

    # ckpt container
    checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
    if lr_scheduler is not None:
        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

    _ensure_directory_exists(checkpoint_path)

    # Same directory => same filesystem, so os.replace can move the file into place atomically.
    # The leading '.' keeps the temporary name from looking like an 'epoch...' checkpoint.
    tmp_path = os.path.join(os.path.dirname(checkpoint_path), f'.{os.path.basename(checkpoint_path)}.tmp')
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Best-effort cleanup of the partial temporary file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
 | ||
| 
 | ||
| 
 | ||
def load_checkpoint(checkpoint_path: str,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    finetune: bool = False,
                    strict: bool = True) -> Tuple:
    """Load a checkpoint file into the given training components.

    If ``finetune`` is False, training is meant to resume from the checkpoint. The model, optimizer
    and (when given and present in the checkpoint) lr_scheduler states are restored, and the stored
    epoch is returned. If ``finetune`` is True, only the model weights and buffers are restored and
    the returned epoch is 0. Loaded entries are popped from the checkpoint dict; the remaining
    entries are returned to the caller.

    :param checkpoint_path: Path of the checkpoint file to load
    :type checkpoint_path: str
    :param model: Model to reload parameters and buffers into
    :type model: torch.nn.Module
    :param optimizer: Optimizer to restore (ignored when ``finetune`` is True)
    :type optimizer: torch.optim.Optimizer
    :param lr_scheduler: lr_scheduler to restore, defaults to None
    :type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
    :param finetune: Whether to only load model weights (e.g. to finetune on a new dataset) instead of
        resuming training, defaults to False
    :type finetune: bool, optional
    :param strict: Whether the keys of the checkpoint's model state_dict must exactly match the keys
        returned by ``model.state_dict()``, defaults to True
    :type strict: bool, optional
    :raises ValueError: If the file does not contain a dict, or a required entry ('epoch' when
        resuming, 'model', 'optimizer' when resuming) is missing
    :return: (epoch to resume from, remaining entries of the checkpoint dict)
    :rtype: Tuple
    """
    # NOTE: torch.load unpickles the file (depending on the torch version/defaults); only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if not isinstance(checkpoint, dict):
        raise ValueError('Checkpoint is corrupted')

    # Only the dict lookups are guarded, so a KeyError raised from inside load_state_dict is not
    # misreported as a corrupted checkpoint.
    try:
        last_epoch = 0 if finetune else checkpoint.pop('epoch')
        model_state = checkpoint.pop('model')
    except KeyError as err:
        raise ValueError('Checkpoint is corrupted') from err
    model.load_state_dict(model_state, strict=strict)

    if not finetune:
        try:
            optimizer_state = checkpoint.pop('optimizer')
        except KeyError as err:
            raise ValueError('Checkpoint is corrupted') from err
        optimizer.load_state_dict(optimizer_state)

        if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler'))

    return last_epoch, checkpoint
 |