mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[lazy] support from_pretrained (#4801)
* [lazy] patch from pretrained * [lazy] fix from pretrained and add tests * [devops] update ci
This commit is contained in:
@@ -3,11 +3,13 @@ import os
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers import LlamaForCausalLM
|
||||
from utils import shared_tempdir
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import (
|
||||
check_state_dict_equal,
|
||||
@@ -120,11 +122,29 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
|
||||
booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
|
||||
|
||||
|
||||
def exam_lazy_from_pretrained():
|
||||
llama_path = os.environ["LLAMA_PATH"]
|
||||
plugin = GeminiPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
orig_model = LlamaForCausalLM.from_pretrained(llama_path)
|
||||
orig_state_dict = {k: v.half() for k, v in orig_model.state_dict().items()}
|
||||
with LazyInitContext():
|
||||
model = LlamaForCausalLM.from_pretrained(llama_path)
|
||||
model, *_ = booster.boost(model)
|
||||
with shared_tempdir() as tempdir:
|
||||
save_path = os.path.join(tempdir, "model.pt")
|
||||
booster.save_model(model, save_path, shard=False)
|
||||
dist.barrier()
|
||||
state_dict = torch.load(save_path, map_location="cpu")
|
||||
check_state_dict_equal(state_dict, orig_state_dict, False)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
exam_state_dict()
|
||||
exam_state_dict_with_origin()
|
||||
exam_lazy_from_pretrained()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
Reference in New Issue
Block a user