[lazy] support from_pretrained (#4801)

* [lazy] patch from pretrained

* [lazy] fix from pretrained and add tests

* [devops] update ci
This commit is contained in:
Hongxin Liu
2023-09-26 11:04:11 +08:00
committed by GitHub
parent 64a08b2dc3
commit 4965c0dabd
11 changed files with 397 additions and 5 deletions

View File

@@ -3,11 +3,13 @@ import os
import pytest
import torch
import torch.distributed as dist
from transformers import LlamaForCausalLM
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
@@ -120,11 +122,29 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
def exam_lazy_from_pretrained():
llama_path = os.environ["LLAMA_PATH"]
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)
orig_model = LlamaForCausalLM.from_pretrained(llama_path)
orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
with LazyInitContext():
model = LlamaForCausalLM.from_pretrained(llama_path)
model, *_ = booster.boost(model)
with shared_tempdir() as tempdir:
save_path = os.path.join(tempdir, "model.pt")
booster.save_model(model, save_path, shard=False)
dist.barrier()
state_dict = torch.load(save_path, map_location="cpu")
check_state_dict_equal(state_dict, orig_state_dict, False)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict()
exam_state_dict_with_origin()
exam_lazy_from_pretrained()
@pytest.mark.dist