diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 291d6adac..e2114d43b 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -141,7 +141,7 @@ jobs:
     runs-on: [self-hosted, gpu]
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 60
     defaults:
       run:
@@ -214,6 +214,7 @@ jobs:
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt
+          LLAMA_PATH: /data/scratch/llama-tiny

       - name: Store Testmon Cache
         run: |
diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml
index 03b47e6cb..6c77377be 100644
--- a/.github/workflows/build_on_schedule.yml
+++ b/.github/workflows/build_on_schedule.yml
@@ -13,7 +13,7 @@ jobs:
     runs-on: [self-hosted, 8-gpu]
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 40
     steps:
       - name: Check GPU Availability # ensure all GPUs have enough memory
@@ -64,6 +64,7 @@ jobs:
         env:
           DATA: /data/scratch/cifar-10
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny

       - name: Notify Lark
         id: message-preparation
diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml
index 2f03c8ced..508321299 100644
--- a/.github/workflows/compatiblity_test_on_dispatch.yml
+++ b/.github/workflows/compatiblity_test_on_dispatch.yml
@@ -50,7 +50,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     steps:
       - name: Install dependencies
@@ -92,3 +92,4 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny
diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml
index a621c7e34..cc17c66f9 100644
--- a/.github/workflows/compatiblity_test_on_pr.yml
+++ b/.github/workflows/compatiblity_test_on_pr.yml
@@ -41,7 +41,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -87,3 +87,4 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny
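These workflow changes mount the tiny LLaMA checkpoint into the CI containers and export its location as `LLAMA_PATH` for the test steps. As a rough sketch of how a test can consume that variable when run outside CI (the skip guard below is illustrative, not part of this diff; the tests added later read `os.environ["LLAMA_PATH"]` directly):

```python
# Illustrative only: a test-side guard for the LLAMA_PATH variable exported by these workflows.
import os

import pytest

LLAMA_PATH = os.environ.get("LLAMA_PATH")  # /data/scratch/llama-tiny on the CI runners


@pytest.mark.skipif(LLAMA_PATH is None, reason="LLAMA_PATH is only set in the CI container")
def test_llama_checkpoint_is_mounted():
    # The extra -v mount above makes the tiny LLaMA checkpoint visible at this path.
    assert os.path.isdir(LLAMA_PATH)
```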
diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml
index 9933224f5..158fe751b 100644
--- a/.github/workflows/compatiblity_test_on_schedule.yml
+++ b/.github/workflows/compatiblity_test_on_schedule.yml
@@ -38,7 +38,7 @@ jobs:
       matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
     container:
       image: ${{ matrix.container }}
-      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
+      options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
     timeout-minutes: 120
     steps:
       - name: Install dependencies
@@ -85,6 +85,7 @@ jobs:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LLAMA_PATH: /data/scratch/llama-tiny

       - name: Notify Lark
         id: message-preparation
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 8d6b0b42e..d73bc5bab 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -131,6 +132,7 @@
         """
         # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
         # TODO(FrankLeeeee): consider multi-dataloader case
+        pretrained_path = pretrained_utils.get_pretrained_path(model)
         # transform model for mixed precision
         if self.plugin:
             model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
@@ -146,6 +148,12 @@
             # when mixed_precision is specified and the plugin is not given or does not control the precision
             model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
+
+        if pretrained_path:
+            self.load_model(model, pretrained_path)
+            # clear pretrained path attr
+            orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
+            pretrained_utils.set_pretrained_path(orig_model, None)

         return model, optimizer, criterion, dataloader, lr_scheduler

     def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
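The booster change is the consumer side of the new pretrained-path attribute: if the model carries a recorded checkpoint path, `boost()` loads it after the plugin has configured the model and then clears the attribute. A minimal sketch of the intended flow, assuming the distributed environment has already been launched (it mirrors the Gemini test added further down; the path comes from `LLAMA_PATH`):

```python
# Minimal sketch (assumes colossalai.launch / launch_from_torch was called beforehand).
import os

from transformers import LlamaForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext

llama_path = os.environ["LLAMA_PATH"]

with LazyInitContext():
    # Patched from_pretrained: no weights are read here, only the resolved
    # checkpoint path is recorded on the model.
    model = LlamaForCausalLM.from_pretrained(llama_path)

booster = Booster(plugin=GeminiPlugin())
# boost() sees the recorded path, lets the plugin shard/materialize the model,
# then calls load_model() with that path and clears the attribute.
model, *_ = booster.boost(model)
```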
diff --git a/colossalai/interface/pretrained.py b/colossalai/interface/pretrained.py
new file mode 100644
index 000000000..2f6bc10cd
--- /dev/null
+++ b/colossalai/interface/pretrained.py
@@ -0,0 +1,16 @@
+from typing import Optional
+
+from torch.nn import Module
+
+__all__ = [
+    "get_pretrained_path",
+    "set_pretrained_path",
+]
+
+
+def get_pretrained_path(model: Module) -> Optional[str]:
+    return getattr(model, "_pretrained", None)
+
+
+def set_pretrained_path(model: Module, path: str) -> None:
+    setattr(model, "_pretrained", path)
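These helpers are deliberately thin: the path is stashed as a private attribute on the module. A quick illustrative round trip (the checkpoint path below is a placeholder):

```python
# Round trip of the new helpers; the path is a placeholder.
import torch.nn as nn

from colossalai.interface.pretrained import get_pretrained_path, set_pretrained_path

m = nn.Linear(2, 2)
assert get_pretrained_path(m) is None  # nothing recorded yet
set_pretrained_path(m, "/path/to/pytorch_model.bin")
assert get_pretrained_path(m) == "/path/to/pytorch_model.bin"
```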
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index f29e997da..a03334b28 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -11,6 +11,7 @@ from torch.utils._pytree import tree_map

 from colossalai.logging import get_dist_logger

 from .construction import ConstructorManager
+from .pretrained import PretrainedManager

 import colossalai._analyzer._subclasses._meta_registration  # noqa
@@ -595,11 +596,13 @@
         )
         ConstructorManager.apply(overrides)
+        PretrainedManager.inject()

     def __exit__(self, exc_type, exc_val, exc_tb):
         self.tensor_cls.default_device = self.old_default_device
         LazyInitContext._replaced = False
         ConstructorManager.clear()
+        PretrainedManager.recover()

     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
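The context manager now patches transformers' `PreTrainedModel.from_pretrained` on entry and restores it on exit. A hypothetical check of that behavior, assuming transformers is installed (not part of the diff):

```python
# Hypothetical check of the patch/restore behavior inside LazyInitContext.
from transformers.modeling_utils import PreTrainedModel

from colossalai.lazy import LazyInitContext

orig_func = PreTrainedModel.from_pretrained.__func__

with LazyInitContext():
    # __enter__ called PretrainedManager.inject(), swapping in new_from_pretrained
    assert PreTrainedModel.from_pretrained.__func__ is not orig_func

# __exit__ called PretrainedManager.recover(), rebuilding the original classmethod
assert PreTrainedModel.from_pretrained.__func__ is orig_func
```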
diff --git a/colossalai/lazy/pretrained.py b/colossalai/lazy/pretrained.py
new file mode 100644
index 000000000..21d44d424
--- /dev/null
+++ b/colossalai/lazy/pretrained.py
@@ -0,0 +1,309 @@
+import os
+from typing import Callable, Optional, Union
+
+import torch
+from torch.nn import Module
+
+from colossalai.interface import pretrained as pretrained_interface
+
+
+class PretrainedManager:
+    old_from_pretrained: Optional[Callable] = None
+
+    @staticmethod
+    def inject() -> None:
+        try:
+            from transformers.modeling_utils import PreTrainedModel
+        except ImportError:
+            return
+        # recover bound method to plain function
+        PretrainedManager.old_from_pretrained = PreTrainedModel.from_pretrained.__func__
+        PreTrainedModel.from_pretrained = new_from_pretrained
+
+    @staticmethod
+    def recover() -> None:
+        try:
+            from transformers.modeling_utils import PreTrainedModel
+        except ImportError:
+            return
+        # convert plain function to class method
+        PreTrainedModel.from_pretrained = classmethod(PretrainedManager.old_from_pretrained)
+        PretrainedManager.old_from_pretrained = None
+
+
+@classmethod
+def new_from_pretrained(
+    cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs
+) -> Module:
+    from transformers import GenerationConfig
+    from transformers.configuration_utils import PretrainedConfig
+    from transformers.modeling_utils import (
+        ContextManagers,
+        _add_variant,
+        cached_file,
+        download_url,
+        has_file,
+        is_offline_mode,
+        is_remote_url,
+        no_init_weights,
+    )
+    from transformers.utils import (
+        SAFE_WEIGHTS_INDEX_NAME,
+        SAFE_WEIGHTS_NAME,
+        WEIGHTS_INDEX_NAME,
+        WEIGHTS_NAME,
+        is_safetensors_available,
+        logging,
+    )
+
+    logger = logging.get_logger(__name__)
+
+    config = kwargs.pop("config", None)
+    cache_dir = kwargs.pop("cache_dir", None)
+    force_download = kwargs.pop("force_download", False)
+    resume_download = kwargs.pop("resume_download", False)
+    proxies = kwargs.pop("proxies", None)
+    local_files_only = kwargs.pop("local_files_only", False)
+    use_auth_token = kwargs.pop("use_auth_token", None)
+    revision = kwargs.pop("revision", None)
+    _ = kwargs.pop("mirror", None)
+    from_pipeline = kwargs.pop("_from_pipeline", None)
+    from_auto_class = kwargs.pop("_from_auto", False)
+    _fast_init = kwargs.pop("_fast_init", True)
+    torch_dtype = kwargs.pop("torch_dtype", None)
+    subfolder = kwargs.pop("subfolder", "")
+    commit_hash = kwargs.pop("_commit_hash", None)
+    variant = kwargs.pop("variant", None)
+    use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+    if len(kwargs) > 0:
+        logger.warning(f"Below kwargs may be ignored: {list(kwargs.keys())}")
+
+    from_pt = True
+
+    user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+    if from_pipeline is not None:
+        user_agent["using_pipeline"] = from_pipeline
+
+    if is_offline_mode() and not local_files_only:
+        logger.info("Offline mode: forcing local_files_only=True")
+        local_files_only = True
+
+    # Load config if we don't provide a configuration
+    if not isinstance(config, PretrainedConfig):
+        config_path = config if config is not None else pretrained_model_name_or_path
+        config, model_kwargs = cls.config_class.from_pretrained(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            _from_auto=from_auto_class,
+            _from_pipeline=from_pipeline,
+            **kwargs,
+        )
+    else:
+        model_kwargs = kwargs
+
+    if commit_hash is None:
+        commit_hash = getattr(config, "_commit_hash", None)
+
+    # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+    # index of the files.
+
+    if pretrained_model_name_or_path is not None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if is_local:
+            if use_safetensors is not False and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
+            ):
+                # Load from a safetensors checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
+                )
+            elif use_safetensors is not False and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                )
+            elif os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+            ):
+                # Load from a PyTorch checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
+                )
+            elif os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+            ):
+                # Load from a sharded PyTorch checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                )
+            else:
+                raise EnvironmentError(
+                    f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
+                    f" {pretrained_model_name_or_path}."
+                )
+        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+            archive_file = pretrained_model_name_or_path
+            is_local = True
+        elif is_remote_url(pretrained_model_name_or_path):
+            filename = pretrained_model_name_or_path
+            resolved_archive_file = download_url(pretrained_model_name_or_path)
+        else:
+            # set correct filename
+            if use_safetensors is not False:
+                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
+            else:
+                filename = _add_variant(WEIGHTS_NAME, variant)
+
+            try:
+                # Load from URL or cache if already cached
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "resume_download": resume_download,
+                    "local_files_only": local_files_only,
+                    "use_auth_token": use_auth_token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
+                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+                # result when internet is up, the repo and revision exist, but the file does not.
+                if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
+                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path,
+                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
+                        **cached_file_kwargs,
+                    )
+                    if resolved_archive_file is not None:
+                        pass
+                    elif use_safetensors:
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+                        )
+                    else:
+                        # This repo has no safetensors file of any kind, we switch to PyTorch.
+                        filename = _add_variant(WEIGHTS_NAME, variant)
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, filename, **cached_file_kwargs
+                        )
+                if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
+                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path,
+                        _add_variant(WEIGHTS_INDEX_NAME, variant),
+                        **cached_file_kwargs,
+                    )
+                    if resolved_archive_file is not None:
+                        pass
+                if resolved_archive_file is None:
+                    # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
+                    # message.
+                    has_file_kwargs = {
+                        "revision": revision,
+                        "proxies": proxies,
+                        "use_auth_token": use_auth_token,
+                    }
+                    if variant is not None and has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
+                            f" {variant}. Use `variant=None` to load this model from those weights."
+                        )
+                    else:
+                        raise EnvironmentError(
+                            f"{pretrained_model_name_or_path} does not appear to have a file named"
+                            f" {_add_variant(WEIGHTS_NAME, variant)}"
+                        )
+            except EnvironmentError:
+                # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+                # to the original exception.
+                raise
+            except Exception:
+                # For any other exception, we throw a generic error.
+                raise EnvironmentError(
+                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                    f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}."
+                )
+
+        if is_local:
+            logger.info(f"loading weights file {archive_file}")
+            resolved_archive_file = archive_file
+        else:
+            logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+    else:
+        resolved_archive_file = None
+
+    if from_pt:
+        # set dtype to instantiate the model under:
+        # 1. If torch_dtype is not None, we use that dtype
+        dtype_orig = None
+
+        if torch_dtype is not None:
+            if not isinstance(torch_dtype, torch.dtype):
+                raise ValueError(f"`torch_dtype` can be either `torch.dtype` or `None`, but received {torch_dtype}")
+            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+
+    config.name_or_path = pretrained_model_name_or_path
+
+    # Instantiate model.
+    init_contexts = [no_init_weights(_enable=_fast_init)]
+
+    with ContextManagers(init_contexts):
+        model = cls(config, *model_args, **model_kwargs)
+
+    if from_pt:
+        # restore default dtype
+        if dtype_orig is not None:
+            torch.set_default_dtype(dtype_orig)
+
+    # make sure token embedding weights are still tied if needed
+    model.tie_weights()
+
+    # Set model in evaluation mode to deactivate DropOut modules by default
+    model.eval()
+
+    # If it is a model with generation capabilities, attempt to load the generation config
+    if model.can_generate():
+        try:
+            model.generation_config = GenerationConfig.from_pretrained(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                **kwargs,
+            )
+        except (OSError, TypeError):
+            logger.info("Generation config file not found, using a generation config created from the model config.")
+
+    # set pretrained path
+    if resolved_archive_file:
+        pretrained_interface.set_pretrained_path(model, resolved_archive_file)
+
+    return model
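Most of `new_from_pretrained` is a trimmed copy of transformers' loader: it resolves which weight file would be used but stops before reading any tensors, records the resolved path via `set_pretrained_path`, and returns the lazily constructed model. For a local directory, the resolution order it follows can be summarized by this simplified sketch (file names are the transformers defaults; the real code also handles variants, remote repos, and the cache):

```python
# Simplified sketch of the local-directory resolution order used above.
import os


def resolve_local_checkpoint(model_dir: str) -> str:
    candidates = [
        "model.safetensors",             # SAFE_WEIGHTS_NAME
        "model.safetensors.index.json",  # SAFE_WEIGHTS_INDEX_NAME (sharded)
        "pytorch_model.bin",             # WEIGHTS_NAME
        "pytorch_model.bin.index.json",  # WEIGHTS_INDEX_NAME (sharded)
    ]
    for name in candidates:
        path = os.path.join(model_dir, name)
        if os.path.isfile(path):
            # For sharded checkpoints the resolved path is the index file, which is
            # what ends up in set_pretrained_path() and later in Booster.load_model().
            return path
    raise EnvironmentError(f"No supported weight file found in {model_dir}")
```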
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index d66dec113..634e81bb2 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -3,11 +3,13 @@ import os
 import pytest
 import torch
 import torch.distributed as dist
+from transformers import LlamaForCausalLM
 from utils import shared_tempdir

 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
+from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import (
     check_state_dict_equal,
@@ -120,11 +122,29 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
         booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)


+def exam_lazy_from_pretrained():
+    llama_path = os.environ["LLAMA_PATH"]
+    plugin = GeminiPlugin()
+    booster = Booster(plugin=plugin)
+    orig_model = LlamaForCausalLM.from_pretrained(llama_path)
+    orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
+    with LazyInitContext():
+        model = LlamaForCausalLM.from_pretrained(llama_path)
+    model, *_ = booster.boost(model)
+    with shared_tempdir() as tempdir:
+        save_path = os.path.join(tempdir, "model.pt")
+        booster.save_model(model, save_path, shard=False)
+        dist.barrier()
+        state_dict = torch.load(save_path, map_location="cpu")
+        check_state_dict_equal(state_dict, orig_state_dict, False)
+
+
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     exam_state_dict()
     exam_state_dict_with_origin()
+    exam_lazy_from_pretrained()


 @pytest.mark.dist
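`exam_lazy_from_pretrained` needs a LLaMA checkpoint at `LLAMA_PATH`; on the CI runners that is the mounted /data/scratch/llama-tiny directory. To reproduce locally, a tiny checkpoint could be generated along these lines (hypothetical helper, not part of the PR; config values are illustrative, and a small `max_shard_size` is used so the save is actually sharded):

```python
# Hypothetical script to build a tiny sharded LLaMA checkpoint for local runs.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=128,
    intermediate_size=256,
    num_attention_heads=4,
    num_hidden_layers=2,
    vocab_size=1024,
)
model = LlamaForCausalLM(config)
# A small max_shard_size forces a sharded save, so the index-file branch of the
# lazy from_pretrained logic is exercised; point LLAMA_PATH at this directory.
model.save_pretrained("/tmp/llama-tiny", max_shard_size="1MB")
```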
diff --git a/tests/test_lazy/test_from_pretrained.py b/tests/test_lazy/test_from_pretrained.py
new file mode 100644
index 000000000..623dd82c5
--- /dev/null
+++ b/tests/test_lazy/test_from_pretrained.py
@@ -0,0 +1,31 @@
+import os
+
+from transformers import BertForPreTraining, LlamaForCausalLM
+
+import colossalai.interface.pretrained as pretrained_utils
+from colossalai.lazy import LazyInitContext
+
+
+def test_lazy_from_pretrained():
+    # test from cached file, unsharded
+    model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
+    with LazyInitContext():
+        deferred_model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
+    pretrained_path = pretrained_utils.get_pretrained_path(deferred_model)
+    assert os.path.isfile(pretrained_path)
+    for p, lazy_p in zip(model.parameters(), deferred_model.parameters()):
+        assert p.shape == lazy_p.shape
+
+    # test from local file, sharded
+    llama_path = os.environ["LLAMA_PATH"]
+    model = LlamaForCausalLM.from_pretrained(llama_path)
+    with LazyInitContext():
+        deferred_model = LlamaForCausalLM.from_pretrained(llama_path)
+    pretrained_path = pretrained_utils.get_pretrained_path(deferred_model)
+    assert os.path.isfile(pretrained_path)
+    for p, lazy_p in zip(model.parameters(), deferred_model.parameters()):
+        assert p.shape == lazy_p.shape
+
+
+if __name__ == "__main__":
+    test_lazy_from_pretrained()
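This test only checks that parameter shapes survive the lazy path and that a checkpoint path was recorded. As a follow-on usage sketch (hypothetical, not in the diff): a lazily constructed model can also be materialized without a Booster, in which case it gets freshly initialized weights, and loading the recorded checkpoint remains the Booster's job:

```python
# Hypothetical follow-on: materializing without a Booster.
from transformers import BertForPreTraining

import colossalai.interface.pretrained as pretrained_utils
from colossalai.lazy import LazyInitContext

with LazyInitContext():
    model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")

# Only the resolved checkpoint path was recorded; no weights were read.
print(pretrained_utils.get_pretrained_path(model))

# Allocates real tensors with fresh (random) initialization; the recorded
# checkpoint still has to be loaded separately, e.g. via Booster.boost().
model = LazyInitContext.materialize(model)
```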