Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-05 03:26:48 +00:00)
[lazy] support from_pretrained (#4801)
* [lazy] patch from pretrained
* [lazy] fix from pretrained and add tests
* [devops] update ci
This commit is contained in:
parent 64a08b2dc3
commit 4965c0dabd
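In short: `from_pretrained` called under `LazyInitContext` now builds the model lazily and records the resolved checkpoint path on the model instead of loading weights, and `Booster.boost` picks that path up and loads the weights after the plugin has wrapped the model. A minimal end-to-end sketch of the intended usage (to be run under torchrun; the checkpoint path is illustrative, matching the CI mount):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from transformers import LlamaForCausalLM

colossalai.launch_from_torch(config={})  # distributed init, required by GeminiPlugin
booster = Booster(plugin=GeminiPlugin())

with LazyInitContext():
    # no weights are loaded here; only the resolved checkpoint path is recorded
    model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-tiny")

model, *_ = booster.boost(model)  # materializes params and loads the recorded checkpoint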
.github/workflows/build_on_pr.yml (vendored, 3 changes)
@@ -141,7 +141,7 @@ jobs:
     runs-on: [self-hosted, gpu]
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 60
     defaults:
       run:
@@ -214,6 +214,7 @@ jobs:
         NCCL_SHM_DISABLE: 1
         LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
         TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt
+        LLAMA_PATH: /data/scratch/llama-tiny

     - name: Store Testmon Cache
       run: |
.github/workflows/build_on_schedule.yml (vendored, 3 changes)
@@ -13,7 +13,7 @@ jobs:
     runs-on: [self-hosted, 8-gpu]
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 40
     steps:
       - name: Check GPU Availability # ensure all GPUs have enough memory
@@ -64,6 +64,7 @@ jobs:
         env:
           DATA: /data/scratch/cifar-10
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny

       - name: Notify Lark
         id: message-preparation
[file header lost in this render: a matrix-based test workflow]
@@ -50,7 +50,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     steps:
       - name: Install dependencies
@@ -92,3 +92,4 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny

[file header lost in this render: a second matrix-based test workflow]
@@ -41,7 +41,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -87,3 +87,4 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny

[file header lost in this render: a third matrix-based test workflow]
@@ -38,7 +38,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     steps:
       - name: Install dependencies
@@ -85,6 +85,7 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny

       - name: Notify Lark
         id: message-preparation
[file header lost in this render: the Booster class, likely colossalai/booster/booster.py]
@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper

@@ -131,6 +132,7 @@ class Booster:
         """
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(FrankLeeeee): consider multi-dataloader case
+        pretrained_path = pretrained_utils.get_pretrained_path(model)
         # transform model for mixed precision
         if self.plugin:
             model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
@@ -146,6 +148,12 @@ class Booster:
             # when mixed_precision is specified and the plugin is not given or does not control the precision
             model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

+        if pretrained_path:
+            self.load_model(model, pretrained_path)
+            # clear pretrained path attr
+            orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
+            pretrained_utils.set_pretrained_path(orig_model, None)
+
         return model, optimizer, criterion, dataloader, lr_scheduler

     def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
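From the caller's side, the effect of this hook is that `boost()` now finishes by routing the recorded checkpoint through `self.load_model`, which delegates to the plugin's checkpoint IO (so sharded plugins load only what they own), then clears the attribute so the path is not loaded twice. A hedged sketch of the equivalence:

# what boost() now does implicitly for a lazily-initialized HF model:
with LazyInitContext():
    model = LlamaForCausalLM.from_pretrained(llama_path)
model, *_ = booster.boost(model)
# ... is roughly: boost the uninitialized model, then
# booster.load_model(model, resolved_checkpoint_path) yourself.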
colossalai/interface/pretrained.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+from typing import Optional
+
+from torch.nn import Module
+
+__all__ = [
+    "get_pretrained_path",
+    "set_pretrained_path",
+]
+
+
+def get_pretrained_path(model: Module) -> Optional[str]:
+    return getattr(model, "_pretrained", None)
+
+
+def set_pretrained_path(model: Module, path: str) -> None:
+    setattr(model, "_pretrained", path)
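These helpers simply stash the path in a `_pretrained` attribute on the module. A quick round-trip for illustration (the path is hypothetical):

import torch.nn as nn

from colossalai.interface.pretrained import get_pretrained_path, set_pretrained_path

m = nn.Linear(2, 2)
assert get_pretrained_path(m) is None  # unset by default
set_pretrained_path(m, "/tmp/ckpt.bin")  # hypothetical path
assert get_pretrained_path(m) == "/tmp/ckpt.bin"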
[file header lost in this render: the LazyInitContext class, likely colossalai/lazy/lazy_init.py]
@@ -11,6 +11,7 @@ from torch.utils._pytree import tree_map
 from colossalai.logging import get_dist_logger

 from .construction import ConstructorManager
+from .pretrained import PretrainedManager

 import colossalai._analyzer._subclasses._meta_registration  # noqa

@@ -595,11 +596,13 @@ class LazyInitContext:
         )

         ConstructorManager.apply(overrides)
+        PretrainedManager.inject()

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.tensor_cls.default_device = self.old_default_device
         LazyInitContext._replaced = False
         ConstructorManager.clear()
+        PretrainedManager.recover()

     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
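The patch is strictly scoped to the context: `__enter__` swaps in the lazy `from_pretrained`, and `__exit__` puts the original back. A small sketch of the observable behavior, assuming `transformers` is installed:

from transformers.modeling_utils import PreTrainedModel

from colossalai.lazy import LazyInitContext

orig = PreTrainedModel.from_pretrained.__func__  # underlying function of the classmethod
with LazyInitContext():
    assert PreTrainedModel.from_pretrained.__func__ is not orig  # patched
assert PreTrainedModel.from_pretrained.__func__ is orig  # restored on exit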
colossalai/lazy/pretrained.py (new file, 309 lines)
@@ -0,0 +1,309 @@
+import os
+from typing import Callable, Optional, Union
+
+import torch
+from torch.nn import Module
+
+from colossalai.interface import pretrained as pretrained_interface
+
+
+class PretrainedManager:
+    old_from_pretrained: Optional[Callable] = None
+
+    @staticmethod
+    def inject() -> None:
+        try:
+            from transformers.modeling_utils import PreTrainedModel
+        except ImportError:
+            return
+        # recover bound method to plain function
+        PretrainedManager.old_from_pretrained = PreTrainedModel.from_pretrained.__func__
+        PreTrainedModel.from_pretrained = new_from_pretrained
+
+    @staticmethod
+    def recover() -> None:
+        try:
+            from transformers.modeling_utils import PreTrainedModel
+        except ImportError:
+            return
+        # convert plain function to class method
+        PreTrainedModel.from_pretrained = classmethod(PretrainedManager.old_from_pretrained)
+        PretrainedManager.old_from_pretrained = None
+
+
+@classmethod
+def new_from_pretrained(
+    cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
+) -> Module:
+    from transformers import GenerationConfig
+    from transformers.configuration_utils import PretrainedConfig
+    from transformers.modeling_utils import (
+        ContextManagers,
+        _add_variant,
+        cached_file,
+        download_url,
+        has_file,
+        is_offline_mode,
+        is_remote_url,
+        no_init_weights,
+    )
+    from transformers.utils import (
+        SAFE_WEIGHTS_INDEX_NAME,
+        SAFE_WEIGHTS_NAME,
+        WEIGHTS_INDEX_NAME,
+        WEIGHTS_NAME,
+        is_safetensors_available,
+        logging,
+    )
+
+    logger = logging.get_logger(__name__)
+
+    config = kwargs.pop("config", None)
+    cache_dir = kwargs.pop("cache_dir", None)
+    force_download = kwargs.pop("force_download", False)
+    resume_download = kwargs.pop("resume_download", False)
+    proxies = kwargs.pop("proxies", None)
+    local_files_only = kwargs.pop("local_files_only", False)
+    use_auth_token = kwargs.pop("use_auth_token", None)
+    revision = kwargs.pop("revision", None)
+    _ = kwargs.pop("mirror", None)
+    from_pipeline = kwargs.pop("_from_pipeline", None)
+    from_auto_class = kwargs.pop("_from_auto", False)
+    _fast_init = kwargs.pop("_fast_init", True)
+    torch_dtype = kwargs.pop("torch_dtype", None)
+    subfolder = kwargs.pop("subfolder", "")
+    commit_hash = kwargs.pop("_commit_hash", None)
+    variant = kwargs.pop("variant", None)
+    use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+    if len(kwargs) > 0:
+        logger.warning(f"Below kwargs may be ignored: {list(kwargs.keys())}")
+
+    from_pt = True
+
+    user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+    if from_pipeline is not None:
+        user_agent["using_pipeline"] = from_pipeline
+
+    if is_offline_mode() and not local_files_only:
+        logger.info("Offline mode: forcing local_files_only=True")
+        local_files_only = True
+
+    # Load config if we don't provide a configuration
+    if not isinstance(config, PretrainedConfig):
+        config_path = config if config is not None else pretrained_model_name_or_path
+        config, model_kwargs = cls.config_class.from_pretrained(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            _from_auto=from_auto_class,
+            _from_pipeline=from_pipeline,
+            **kwargs,
+        )
+    else:
+        model_kwargs = kwargs
+
+    if commit_hash is None:
+        commit_hash = getattr(config, "_commit_hash", None)
+
+    # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+    # index of the files.
+
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if is_local:
+            if use_safetensors is not False and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
+            ):
+                # Load from a safetensors checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
+                )
+            elif use_safetensors is not False and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                )
+            elif os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+            ):
+                # Load from a PyTorch checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
+                )
+            elif os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+            ):
+                # Load from a sharded PyTorch checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                )
+            else:
+                raise EnvironmentError(
+                    f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
+                    f" {pretrained_model_name_or_path}."
+                )
+        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+            archive_file = pretrained_model_name_or_path
+            is_local = True
+        elif is_remote_url(pretrained_model_name_or_path):
+            filename = pretrained_model_name_or_path
+            resolved_archive_file = download_url(pretrained_model_name_or_path)
+        else:
+            # set correct filename
+            if use_safetensors is not False:
+                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
+            else:
+                filename = _add_variant(WEIGHTS_NAME, variant)
+
+            try:
+                # Load from URL or cache if already cached
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "use_auth_token": use_auth_token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
+                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+                # result when internet is up, the repo and revision exist, but the file does not.
+                if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
+                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path,
+                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
+                        **cached_file_kwargs,
+                    )
+                    if resolved_archive_file is not None:
+                        pass
+                    elif use_safetensors:
+                        raise EnvironmentError(
+                            f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+                        )
+                    else:
+                        # This repo has no safetensors file of any kind, we switch to PyTorch.
+                        filename = _add_variant(WEIGHTS_NAME, variant)
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, filename, **cached_file_kwargs
+                        )
+                if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
+                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path,
+                        _add_variant(WEIGHTS_INDEX_NAME, variant),
+                        **cached_file_kwargs,
+                    )
+                    if resolved_archive_file is not None:
+                        pass
+                if resolved_archive_file is None:
+                    # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
+                    # message.
+                    has_file_kwargs = {
+                        "revision": revision,
+                        "proxies": proxies,
+                        "use_auth_token": use_auth_token,
+                    }
+                    if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
+                            f" {variant}. Use `variant=None` to load this model from those weights."
+                        )
+                    else:
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(WEIGHTS_NAME, variant)}"
+                        )
+            except EnvironmentError:
+                # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+                # to the original exception.
+                raise
+            except Exception:
+                # For any other exception, we throw a generic error.
+                raise EnvironmentError(
+                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                    f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}."
+                )
+
+        if is_local:
+            logger.info(f"loading weights file {archive_file}")
+            resolved_archive_file = archive_file
+        else:
+            logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+    else:
+        resolved_archive_file = None
+
+    if from_pt:
+        # set dtype to instantiate the model under:
+        # 1. If torch_dtype is not None, we use that dtype
+        dtype_orig = None
+
+        if torch_dtype is not None:
+            if not isinstance(torch_dtype, torch.dtype):
+                raise ValueError(f"`torch_dtype` can be either `torch.dtype` or `None`, but received {torch_dtype}")
+            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+
+    config.name_or_path = pretrained_model_name_or_path
+
+    # Instantiate model.
+    init_contexts = [no_init_weights(_enable=_fast_init)]
+
+    with ContextManagers(init_contexts):
+        model = cls(config, *model_args, **model_kwargs)
+
+    if from_pt:
+        # restore default dtype
+        if dtype_orig is not None:
+            torch.set_default_dtype(dtype_orig)
+
+    # make sure token embedding weights are still tied if needed
+    model.tie_weights()
+
+    # Set model in evaluation mode to deactivate DropOut modules by default
+    model.eval()
+
+    # If it is a model with generation capabilities, attempt to load the generation config
+    if model.can_generate():
+        try:
+            model.generation_config = GenerationConfig.from_pretrained(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                **kwargs,
+            )
+        except (OSError, TypeError):
+            logger.info("Generation config file not found, using a generation config created from the model config.")
+
+    # set pretrained path
+    if resolved_archive_file:
+        pretrained_interface.set_pretrained_path(model, resolved_archive_file)
+
+    return model
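`new_from_pretrained` follows the upstream transformers resolution logic (local directory, single file, remote URL, hub cache; safetensors preferred, sharded index as fallback) but deliberately stops before loading any state dict: the model is instantiated under `no_init_weights`, and only the resolved archive path is recorded. A sketch of inspecting that record (the model name is taken from the tests below):

import colossalai.interface.pretrained as pretrained_utils
from colossalai.lazy import LazyInitContext
from transformers import BertForPreTraining

with LazyInitContext():
    model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")

# path to the resolved weights file (or shard index), not yet loaded
print(pretrained_utils.get_pretrained_path(model))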
[file header lost in this render: likely the Gemini checkpoint IO test]
@@ -3,11 +3,13 @@ import os
 import pytest
 import torch
 import torch.distributed as dist
+from transformers import LlamaForCausalLM
 from utils import shared_tempdir

 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
+from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import (
     check_state_dict_equal,
@@ -120,11 +122,29 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
         booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


+def exam_lazy_from_pretrained():
+    llama_path = os.environ["LLAMA_PATH"]
+    plugin = GeminiPlugin()
+    booster = Booster(plugin=plugin)
+    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
+    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
+    with LazyInitContext():
+        model = LlamaForCausalLM.from_pretrained(llama_path)
+    model, *_ = booster.boost(model)
+    with shared_tempdir() as tempdir:
+        save_path = os.path.join(tempdir, "model.pt")
+        booster.save_model(model, save_path, shard=False)
+        dist.barrier()
+        state_dict = torch.load(save_path, map_location="cpu")
+        check_state_dict_equal(state_dict, orig_state_dict, False)
+
+
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     exam_state_dict()
     exam_state_dict_with_origin()
+    exam_lazy_from_pretrained()


 @pytest.mark.dist
tests/test_lazy/test_from_pretrained.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import os
+
+from transformers import BertForPreTraining, LlamaForCausalLM
+
+import colossalai.interface.pretrained as pretrained_utils
+from colossalai.lazy import LazyInitContext
+
+
+def test_lazy_from_pretrained():
+    # test from cached file, unsharded
+    model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
+    with LazyInitContext():
+        deferred_model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
+    pretrained_path = pretrained_utils.get_pretrained_path(deferred_model)
+    assert os.path.isfile(pretrained_path)
+    for p, lazy_p in zip(model.parameters(), deferred_model.parameters()):
+        assert p.shape == lazy_p.shape
+
+    # test from local file, sharded
+    llama_path = os.environ["LLAMA_PATH"]
+    model = LlamaForCausalLM.from_pretrained(llama_path)
+    with LazyInitContext():
+        deferred_model = LlamaForCausalLM.from_pretrained(llama_path)
+    pretrained_path = pretrained_utils.get_pretrained_path(deferred_model)
+    assert os.path.isfile(pretrained_path)
+    for p, lazy_p in zip(model.parameters(), deferred_model.parameters()):
+        assert p.shape == lazy_p.shape
+
+
+if __name__ == "__main__":
+    test_lazy_from_pretrained()
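To run these tests outside CI, a tiny LLaMA checkpoint must exist locally and be exported as LLAMA_PATH, mirroring the container mount added above, e.g. LLAMA_PATH=/data/scratch/llama-tiny pytest tests/test_lazy/test_from_pretrained.py.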