[lazy] support from_pretrained (#4801)

* [lazy] patch from pretrained

* [lazy] fix from pretrained and add tests

* [devops] update ci
Hongxin Liu 2023-09-26 11:04:11 +08:00 committed by GitHub
parent 64a08b2dc3
commit 4965c0dabd
11 changed files with 397 additions and 5 deletions
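
For context, a minimal end-to-end sketch of what this PR enables, adapted from the tests added below. The launch call and the LLAMA_PATH checkpoint location are assumptions about the runtime environment, mirroring the CI volumes added in this PR.

import os

from transformers import LlamaForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext

colossalai.launch_from_torch(config={})  # assumes a torchrun-style launch
llama_path = os.environ["LLAMA_PATH"]  # path to a (tiny) LLaMA checkpoint, as in the CI setup

with LazyInitContext():
    # from_pretrained is patched inside the context: the model is built lazily,
    # pretrained weights are not loaded yet, and the checkpoint path is recorded on the model
    model = LlamaForCausalLM.from_pretrained(llama_path)

booster = Booster(plugin=GeminiPlugin())
# boost materializes the lazy model and loads the recorded checkpoint via booster.load_model
model, *_ = booster.boost(model)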

View File

@@ -141,7 +141,7 @@ jobs:
     runs-on: [self-hosted, gpu]
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 60
     defaults:
       run:
@@ -214,6 +214,7 @@ jobs:
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt
+          LLAMA_PATH: /data/scratch/llama-tiny
       - name: Store Testmon Cache
         run: |

View File

@@ -13,7 +13,7 @@ jobs:
     runs-on: [self-hosted, 8-gpu]
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 40
     steps:
       - name: Check GPU Availability # ensure all GPUs have enough memory
@@ -64,6 +64,7 @@ jobs:
         env:
           DATA: /data/scratch/cifar-10
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny
       - name: Notify Lark
         id: message-preparation

View File

@@ -50,7 +50,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     steps:
       - name: Install dependencies
@@ -92,3 +92,4 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny

View File

@@ -41,7 +41,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -87,3 +87,4 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny

View File

@@ -38,7 +38,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     steps:
       - name: Install dependencies
@@ -85,6 +85,7 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny
       - name: Notify Lark
         id: message-preparation

View File

@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -131,6 +132,7 @@ class Booster:
         """
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(FrankLeeeee): consider multi-dataloader case
+        pretrained_path = pretrained_utils.get_pretrained_path(model)
         # transform model for mixed precision
         if self.plugin:
             model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
@@ -146,6 +148,12 @@ class Booster:
             # when mixed_precision is specified and the plugin is not given or does not control the precision
             model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
 
+        if pretrained_path:
+            self.load_model(model, pretrained_path)
+            # clear pretrained path attr
+            orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
+            pretrained_utils.set_pretrained_path(orig_model, None)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:

View File

@@ -0,0 +1,16 @@
from typing import Optional

from torch.nn import Module

__all__ = [
    "get_pretrained_path",
    "set_pretrained_path",
]


def get_pretrained_path(model: Module) -> Optional[str]:
    return getattr(model, "_pretrained", None)


def set_pretrained_path(model: Module, path: str) -> None:
    setattr(model, "_pretrained", path)
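
A toy illustration of these helpers (not part of the PR; the module and path below are made up): they simply stash and read a checkpoint path on a module, which Booster.boost later consumes.

import torch.nn as nn

from colossalai.interface.pretrained import get_pretrained_path, set_pretrained_path

layer = nn.Linear(2, 2)
assert get_pretrained_path(layer) is None  # nothing recorded yet
set_pretrained_path(layer, "/tmp/model.bin")  # hypothetical checkpoint path
assert get_pretrained_path(layer) == "/tmp/model.bin"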

View File

@@ -11,6 +11,7 @@ from torch.utils._pytree import tree_map
 from colossalai.logging import get_dist_logger
 
 from .construction import ConstructorManager
+from .pretrained import PretrainedManager
 
 import colossalai._analyzer._subclasses._meta_registration  # noqa
@@ -595,11 +596,13 @@ class LazyInitContext:
         )
         ConstructorManager.apply(overrides)
+        PretrainedManager.inject()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.tensor_cls.default_device = self.old_default_device
         LazyInitContext._replaced = False
         ConstructorManager.clear()
+        PretrainedManager.recover()
 
     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:

View File

@@ -0,0 +1,309 @@
import os
from typing import Callable, Optional, Union

import torch
from torch.nn import Module

from colossalai.interface import pretrained as pretrained_interface


class PretrainedManager:
    old_from_pretrained: Optional[Callable] = None

    @staticmethod
    def inject() -> None:
        try:
            from transformers.modeling_utils import PreTrainedModel
        except ImportError:
            return
        # recover bound method to plain function
        PretrainedManager.old_from_pretrained = PreTrainedModel.from_pretrained.__func__
        PreTrainedModel.from_pretrained = new_from_pretrained

    @staticmethod
    def recover() -> None:
        try:
            from transformers.modeling_utils import PreTrainedModel
        except ImportError:
            return
        # convert plain function to class method
        PreTrainedModel.from_pretrained = classmethod(PretrainedManager.old_from_pretrained)
        PretrainedManager.old_from_pretrained = None
@classmethod
def new_from_pretrained(
    cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
) -> Module:
    from transformers import GenerationConfig
    from transformers.configuration_utils import PretrainedConfig
    from transformers.modeling_utils import (
        ContextManagers,
        _add_variant,
        cached_file,
        download_url,
        has_file,
        is_offline_mode,
        is_remote_url,
        no_init_weights,
    )
    from transformers.utils import (
        SAFE_WEIGHTS_INDEX_NAME,
        SAFE_WEIGHTS_NAME,
        WEIGHTS_INDEX_NAME,
        WEIGHTS_NAME,
        is_safetensors_available,
        logging,
    )

    logger = logging.get_logger(__name__)

    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    _ = kwargs.pop("mirror", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
    _fast_init = kwargs.pop("_fast_init", True)
    torch_dtype = kwargs.pop("torch_dtype", None)
    subfolder = kwargs.pop("subfolder", "")
    commit_hash = kwargs.pop("_commit_hash", None)
    variant = kwargs.pop("variant", None)
    use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

    if len(kwargs) > 0:
        logger.warning(f"Below kwargs may be ignored: {list(kwargs.keys())}")

    from_pt = True

    user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    if commit_hash is None:
        commit_hash = getattr(config, "_commit_hash", None)
    # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
    # index of the files.
    if pretrained_model_name_or_path is not None:
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        if is_local:
            if use_safetensors is not False and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
            ):
                # Load from a safetensors checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
                )
            elif use_safetensors is not False and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
            ):
                # Load from a sharded safetensors checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
                )
            elif os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
            ):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
                )
            elif os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
            ):
                # Load from a sharded PyTorch checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
                )
            else:
                raise EnvironmentError(
                    f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
                    f" {pretrained_model_name_or_path}."
                )
        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
            archive_file = pretrained_model_name_or_path
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):
            filename = pretrained_model_name_or_path
            resolved_archive_file = download_url(pretrained_model_name_or_path)
        else:
            # set correct filename
            if use_safetensors is not False:
                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
            else:
                filename = _add_variant(WEIGHTS_NAME, variant)

            try:
                # Load from URL or cache if already cached
                cached_file_kwargs = {
                    "cache_dir": cache_dir,
                    "force_download": force_download,
                    "proxies": proxies,
                    "resume_download": resume_download,
                    "local_files_only": local_files_only,
                    "use_auth_token": use_auth_token,
                    "user_agent": user_agent,
                    "revision": revision,
                    "subfolder": subfolder,
                    "_raise_exceptions_for_missing_entries": False,
                    "_commit_hash": commit_hash,
                }
                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                # result when internet is up, the repo and revision exist, but the file does not.
                if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                    if resolved_archive_file is not None:
                        pass
                    elif use_safetensors:
                        raise EnvironmentError(
                            f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
                        )
                    else:
                        # This repo has no safetensors file of any kind, we switch to PyTorch.
                        filename = _add_variant(WEIGHTS_NAME, variant)
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, filename, **cached_file_kwargs
                        )
                if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    resolved_archive_file = cached_file(
                        pretrained_model_name_or_path,
                        _add_variant(WEIGHTS_INDEX_NAME, variant),
                        **cached_file_kwargs,
                    )
                    if resolved_archive_file is not None:
                        pass
                if resolved_archive_file is None:
                    # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
                    # message.
                    has_file_kwargs = {
                        "revision": revision,
                        "proxies": proxies,
                        "use_auth_token": use_auth_token,
                    }
                    if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
                            f" {variant}. Use `variant=None` to load this model from those weights."
                        )
                    else:
                        raise EnvironmentError(
                            f"{pretrained_model_name_or_path} does not appear to have a file named"
                            f" {_add_variant(WEIGHTS_NAME, variant)}"
                        )
            except EnvironmentError:
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
                # to the original exception.
                raise
            except Exception:
                # For any other exception, we throw a generic error.
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                    f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}."
                )

        if is_local:
            logger.info(f"loading weights file {archive_file}")
            resolved_archive_file = archive_file
        else:
            logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
    else:
        resolved_archive_file = None
    if from_pt:
        # set dtype to instantiate the model under:
        # 1. If torch_dtype is not None, we use that dtype
        dtype_orig = None
        if torch_dtype is not None:
            if not isinstance(torch_dtype, torch.dtype):
                raise ValueError(f"`torch_dtype` can be either `torch.dtype` or `None`, but received {torch_dtype}")
            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

    config.name_or_path = pretrained_model_name_or_path

    # Instantiate model.
    init_contexts = [no_init_weights(_enable=_fast_init)]

    with ContextManagers(init_contexts):
        model = cls(config, *model_args, **model_kwargs)

    if from_pt:
        # restore default dtype
        if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)

    # make sure token embedding weights are still tied if needed
    model.tie_weights()

    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()

    # If it is a model with generation capabilities, attempt to load the generation config
    if model.can_generate():
        try:
            model.generation_config = GenerationConfig.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        except (OSError, TypeError):
            logger.info("Generation config file not found, using a generation config created from the model config.")

    # set pretrained path
    if resolved_archive_file:
        pretrained_interface.set_pretrained_path(model, resolved_archive_file)

    return model
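
As a usage note, LazyInitContext drives this manager through inject()/recover() in its __enter__/__exit__ (see the lazy_init diff above). A hedged sketch of using it directly, assuming transformers is installed and reusing the tiny checkpoint from the unit test below:

from transformers import BertForPreTraining

import colossalai.interface.pretrained as pretrained_utils
from colossalai.lazy.pretrained import PretrainedManager

PretrainedManager.inject()  # swap in new_from_pretrained
try:
    # config-only construction: pretrained weights are not loaded, but the resolved
    # checkpoint file is recorded on the model for a later booster.load_model call
    model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
    print(pretrained_utils.get_pretrained_path(model))
finally:
    PretrainedManager.recover()  # restore the original from_pretrained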

View File

@@ -3,11 +3,13 @@ import os
 import pytest
 import torch
 import torch.distributed as dist
+from transformers import LlamaForCausalLM
 from utils import shared_tempdir
 
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
+from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import (
     check_state_dict_equal,
@@ -120,11 +122,29 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
 
 
+def exam_lazy_from_pretrained():
+    llama_path = os.environ["LLAMA_PATH"]
+    plugin = GeminiPlugin()
+    booster = Booster(plugin=plugin)
+    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
+    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
+    with LazyInitContext():
+        model = LlamaForCausalLM.from_pretrained(llama_path)
+    model, *_ = booster.boost(model)
+    with shared_tempdir() as tempdir:
+        save_path = os.path.join(tempdir, "model.pt")
+        booster.save_model(model, save_path, shard=False)
+        dist.barrier()
+        state_dict = torch.load(save_path, map_location="cpu")
+        check_state_dict_equal(state_dict, orig_state_dict, False)
+
+
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     exam_state_dict()
     exam_state_dict_with_origin()
+    exam_lazy_from_pretrained()
 
 
 @pytest.mark.dist

View File

@@ -0,0 +1,31 @@
import os

from transformers import BertForPreTraining, LlamaForCausalLM

import colossalai.interface.pretrained as pretrained_utils
from colossalai.lazy import LazyInitContext


def test_lazy_from_pretrained():
    # test from cached file, unsharded
    model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
    with LazyInitContext():
        deffered_model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
    pretrained_path = pretrained_utils.get_pretrained_path(deffered_model)
    assert os.path.isfile(pretrained_path)
    for p, lazy_p in zip(model.parameters(), deffered_model.parameters()):
        assert p.shape == lazy_p.shape

    # test from local file, sharded
    llama_path = os.environ["LLAMA_PATH"]
    model = LlamaForCausalLM.from_pretrained(llama_path)
    with LazyInitContext():
        deffered_model = LlamaForCausalLM.from_pretrained(llama_path)
    pretrained_path = pretrained_utils.get_pretrained_path(deffered_model)
    assert os.path.isfile(pretrained_path)
    for p, lazy_p in zip(model.parameters(), deffered_model.parameters()):
        assert p.shape == lazy_p.shape


if __name__ == "__main__":
    test_lazy_from_pretrained()