Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 21:40:02 +00:00)
[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load
* gemini plugin support shard checkpoint
* [API Refactoring] gemini plugin support shard checkpoint

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
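For context on what the new tests exercise: this commit adds sharded checkpoint save/load to the Gemini plugin, writing an index file plus several shard files instead of one monolithic state dict. A minimal usage sketch through the Booster front end follows. It is an illustration rather than code from this commit, and the keyword names shard and size_per_shard follow the public Booster.save_model API, which may differ slightly at this revision.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from torchvision.models import resnet18

# assumes the script is launched with `colossalai run` / torchrun so the
# distributed environment variables are already set
colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(placement_policy='cuda')
booster = Booster(plugin=plugin)

model = resnet18()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# save as a sharded checkpoint: a directory holding an index JSON plus shard files
booster.save_model(model, './gemini_ckpt', shard=True, size_per_shard=32)

# load the shards back into a boosted model (e.g. in a fresh process)
booster.load_model(model, './gemini_ckpt')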
@@ -1,16 +1,21 @@
import tempfile
import pytest
import torch
import logging
from torch.optim import Adam
from torchvision.models import resnet18
from pathlib import Path
import os
import subprocess

from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.testing import clear_cache_before_run, parameterize

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs

# ========
# Note:
# 1. checkpoint IO can be quite slow if tested with all models, so we only test on resnet for now
@@ -83,7 +88,6 @@ def test_sharded_checkpoint(use_safetensors: bool):
    suffix = ".bin"
    WEIGHTS_INDEX_NAME = "model.bin.index.json"

    # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix)
    model_ckpt_dir = tempfile.TemporaryDirectory()
    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
@@ -104,6 +108,87 @@ def test_sharded_checkpoint(use_safetensors: bool):
    recursive_check(model.state_dict(), new_model.state_dict())
    recursive_check(optimizer.state_dict(), new_optimizer.state_dict())


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['bert'])
@parameterize('use_safetensors', [True, False])
def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool):
    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification

    model_ckpt_dir = tempfile.TemporaryDirectory()
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, *_ = get_components_func()

    with ColoInitContext(device=get_current_device()):
        bert_model = model_builder()
    bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name)
    config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    bert_model = ZeroDDP(bert_model, gemini_manager)
    bert_model.train()

    ckpt_io = GeminiCheckpointIO()
    if ckpt_io.coordinator.is_master():
        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
        ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors)
        new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
        recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict())

    model_ckpt_dir.cleanup()


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
@parameterize('use_safetensors', [True, False])
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, *_ = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
        new_model = model_builder()

    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    model.train()

    new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
    new_chunk_manager = ChunkManager(new_config_dict)
    new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
    new_model = ZeroDDP(new_model, new_gemini_manager)

    model_ckpt_dir = tempfile.TemporaryDirectory()

    ckpt_io = GeminiCheckpointIO()
    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors)

    # load model
    if ckpt_io.coordinator.is_master():
        ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True)
        model_dict = model.state_dict(only_rank_0=True)
        new_model_dict = new_model.state_dict(only_rank_0=True)
        recursive_check(model_dict, new_model_dict)

    model_ckpt_dir.cleanup()


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()
    hf_load_colossalai_checkpoint()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4, 4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
    spawn(run_dist, world_size)


# do recursive check for the optimizer state dict
# if the value is a dict, compare its values
@@ -117,10 +202,14 @@ def recursive_check(d1, d2):
        elif isinstance(v, list):
            for i in range(len(v)):
                if isinstance(v[i], torch.Tensor):
                    v[i] = v[i].to("cpu")
                    d2[k][i] = d2[k][i].to("cpu")
                    assert torch.equal(v[i], d2[k][i])
                else:
                    assert v[i] == d2[k][i]
        elif isinstance(v, torch.Tensor):
            v = v.to("cpu")
            d2[k] = d2[k].to("cpu")
            assert torch.equal(v, d2[k])
        else:
            assert v == d2[k]
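The hf_load_colossalai_checkpoint test above only passes if the sharded layout written by GeminiCheckpointIO matches the Hugging Face convention that from_pretrained expects: an index JSON (the WEIGHTS_INDEX_NAME constant, model.bin.index.json) whose weight_map points each parameter name at the shard file holding it. A small illustrative helper for inspecting such a directory is sketched below; the exact on-disk layout is an assumption based on that convention, not something this diff spells out.

import json
import os

def list_shard_files(ckpt_dir, index_name="model.bin.index.json"):
    # read the Hugging Face style index: {"metadata": {...}, "weight_map": {param: shard_file}}
    with open(os.path.join(ckpt_dir, index_name)) as f:
        index = json.load(f)
    # return the distinct shard files that make up the checkpoint
    return sorted(set(index["weight_map"].values()))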
@@ -31,14 +31,13 @@ def exam_state_dict(placement_policy, model_name: str):
    zero_dict = model.state_dict(only_rank_0=False)
    accumulated_keys = set()
    # ensure number of shards > 1
-   for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
+   for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
        for key, value in shard.items():
            assert key not in accumulated_keys, f"key `{key}` is duplicated."
            accumulated_keys.add(key)
            assert key in zero_dict, f"{key} not in ZeRO dictionary."
            assert torch.equal(value, zero_dict[key]), f"{key} not equal."


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
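The hunk above tracks an API change in ZeroDDP.state_dict_shard: the generator now yields a tuple rather than a bare dict, and the test unpacks it as (shard, _), keeping only the partial state dict. A hedged sketch of consuming the generator to write one file per shard follows; treating the second tuple element as opaque bookkeeping is an assumption, since this diff does not show what it contains.

import os
import torch

def save_shards(model, ckpt_dir, max_shard_size):
    # model is assumed to be a ZeroDDP-wrapped module exposing state_dict_shard,
    # which yields (shard_state_dict, extra) pairs as in the updated test above
    os.makedirs(ckpt_dir, exist_ok=True)
    for idx, (shard, _) in enumerate(
            model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=False)):
        torch.save(shard, os.path.join(ckpt_dir, f"shard_{idx:05d}.bin"))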