mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-04 19:16:42 +00:00
[checkpointio] support huggingface from_pretrained for all plugins (#4606)
This commit is contained in:
parent
0a94fcd351
commit
e79b1e80e2
@ -18,6 +18,7 @@ from colossalai.checkpoint_io.utils import (
|
|||||||
get_optimizer_base_filenames,
|
get_optimizer_base_filenames,
|
||||||
get_shard_filename,
|
get_shard_filename,
|
||||||
load_shard_state_dict,
|
load_shard_state_dict,
|
||||||
|
save_config_file,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
save_state_dict_shards,
|
save_state_dict_shards,
|
||||||
)
|
)
|
||||||
@ -111,6 +112,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
index_file.append_meta_data("total_size", total_size)
|
index_file.append_meta_data("total_size", total_size)
|
||||||
index_file.write_index_file(save_index_file)
|
index_file.write_index_file(save_index_file)
|
||||||
|
save_config_file(model.module, checkpoint_path)
|
||||||
logging.info(f"The model is split into checkpoint shards. "
|
logging.info(f"The model is split into checkpoint shards. "
|
||||||
f"You can find where each parameters has been saved in the "
|
f"You can find where each parameters has been saved in the "
|
||||||
f"index located at {save_index_file}.")
|
f"index located at {save_index_file}.")
|
||||||
|
@ -23,6 +23,7 @@ from .utils import (
|
|||||||
load_state_dict,
|
load_state_dict,
|
||||||
load_state_dict_into_model,
|
load_state_dict_into_model,
|
||||||
load_states_into_optimizer,
|
load_states_into_optimizer,
|
||||||
|
save_config_file,
|
||||||
save_param_groups,
|
save_param_groups,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
save_state_dict_shards,
|
save_state_dict_shards,
|
||||||
@ -185,6 +186,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||||||
|
|
||||||
index_file.append_meta_data("total_size", total_size)
|
index_file.append_meta_data("total_size", total_size)
|
||||||
index_file.write_index_file(save_index_file)
|
index_file.write_index_file(save_index_file)
|
||||||
|
save_config_file(model, checkpoint_path, is_master=True)
|
||||||
logging.info(f"The model is going to be split to checkpoint shards. "
|
logging.info(f"The model is going to be split to checkpoint shards. "
|
||||||
f"You can find where each parameters has been saved in the "
|
f"You can find where each parameters has been saved in the "
|
||||||
f"index located at {save_index_file}.")
|
f"index located at {save_index_file}.")
|
||||||
|
@ -1,129 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from torch.optim import Adam
|
|
||||||
from utils import shared_tempdir
|
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.booster import Booster
|
|
||||||
from colossalai.booster.plugin import HybridParallelPlugin
|
|
||||||
from colossalai.shardformer.layer.utils import Randomizer
|
|
||||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
|
||||||
from colossalai.testing import (
|
|
||||||
check_state_dict_equal,
|
|
||||||
clear_cache_before_run,
|
|
||||||
parameterize,
|
|
||||||
rerun_if_address_is_in_use,
|
|
||||||
spawn,
|
|
||||||
)
|
|
||||||
from tests.kit.model_zoo import model_zoo
|
|
||||||
|
|
||||||
|
|
||||||
def exam_from_pretrained(model_fn,
|
|
||||||
data_gen_fn,
|
|
||||||
output_transform_fn,
|
|
||||||
loss_fn,
|
|
||||||
test_config,
|
|
||||||
shard=True,
|
|
||||||
size_per_shard=32):
|
|
||||||
|
|
||||||
def _criterion(outputs, inputs):
|
|
||||||
outputs = output_transform_fn(outputs)
|
|
||||||
loss = criterion(outputs)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def _preprocess_data(data):
|
|
||||||
if booster.plugin.stage_manager is not None:
|
|
||||||
for k, v in data.items():
|
|
||||||
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
|
|
||||||
new_shape = [1] * v.dim()
|
|
||||||
new_shape[0] = 4
|
|
||||||
data[k] = v.to('cuda').repeat(*new_shape)
|
|
||||||
return iter([data])
|
|
||||||
else:
|
|
||||||
return {k: v.cuda() for k, v in data.items()}
|
|
||||||
|
|
||||||
model = model_fn()
|
|
||||||
optimizer = Adam((model.parameters()), lr=0.001)
|
|
||||||
criterion = loss_fn
|
|
||||||
plugin = HybridParallelPlugin(**test_config)
|
|
||||||
booster = Booster(plugin=plugin)
|
|
||||||
|
|
||||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
|
||||||
|
|
||||||
data = data_gen_fn()
|
|
||||||
model.train()
|
|
||||||
if booster.plugin.stage_manager is not None:
|
|
||||||
booster.execute_pipeline(_preprocess_data(data),
|
|
||||||
model,
|
|
||||||
_criterion,
|
|
||||||
optimizer,
|
|
||||||
return_loss=True,
|
|
||||||
return_outputs=False)
|
|
||||||
else:
|
|
||||||
output = model(**_preprocess_data(data))
|
|
||||||
loss = criterion(output)
|
|
||||||
optimizer.backward(loss)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
with shared_tempdir() as tempdir:
|
|
||||||
|
|
||||||
model_ckpt_path = f"{tempdir}/model"
|
|
||||||
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path)
|
|
||||||
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
|
|
||||||
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
|
|
||||||
|
|
||||||
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
|
|
||||||
|
|
||||||
Randomizer.reset_index()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
@clear_cache_before_run()
|
|
||||||
@parameterize('test_config', [{
|
|
||||||
'tp_size': 4,
|
|
||||||
'pp_size': 1,
|
|
||||||
'precision': 'fp32',
|
|
||||||
}, {
|
|
||||||
'tp_size': 2,
|
|
||||||
'pp_size': 2,
|
|
||||||
'num_microbatches': 4,
|
|
||||||
'precision': 'fp16',
|
|
||||||
'initial_scale': 1
|
|
||||||
}, {
|
|
||||||
'tp_size': 2,
|
|
||||||
'pp_size': 1,
|
|
||||||
'zero_stage': 2,
|
|
||||||
'precision': 'fp16',
|
|
||||||
'initial_scale': 1
|
|
||||||
}, {
|
|
||||||
'tp_size': 1,
|
|
||||||
'pp_size': 2,
|
|
||||||
'num_microbatches': 4,
|
|
||||||
'zero_stage': 1,
|
|
||||||
'precision': 'fp16',
|
|
||||||
'initial_scale': 1
|
|
||||||
}])
|
|
||||||
def run_test(test_config):
|
|
||||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
|
||||||
exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
|
||||||
clear_layout_converter()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
|
||||||
config = {}
|
|
||||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
|
||||||
run_test()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
|
||||||
@pytest.mark.parametrize('world_size', [4])
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
def test_huggingface_compatibility(world_size):
|
|
||||||
spawn(run_dist, world_size)
|
|
@ -0,0 +1,83 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from utils import shared_tempdir
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||||
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.testing import (
|
||||||
|
check_state_dict_equal,
|
||||||
|
clear_cache_before_run,
|
||||||
|
parameterize,
|
||||||
|
rerun_if_address_is_in_use,
|
||||||
|
spawn,
|
||||||
|
)
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
|
|
||||||
|
@clear_cache_before_run()
|
||||||
|
@parameterize('model_name', ['transformers_gpt'])
|
||||||
|
@parameterize('plugin_type', ['ddp', 'zero', 'gemini'])
|
||||||
|
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
|
||||||
|
(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||||
|
_) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||||
|
criterion = loss_fn
|
||||||
|
|
||||||
|
if plugin_type == 'ddp':
|
||||||
|
plugin = TorchDDPPlugin()
|
||||||
|
elif plugin_type == 'zero':
|
||||||
|
plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
|
||||||
|
elif plugin_type == 'gemini':
|
||||||
|
plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")
|
||||||
|
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
model = model_fn().cuda()
|
||||||
|
model_huggingface_cls = model.__class__
|
||||||
|
optimizer = HybridAdam(model.parameters(), lr=0.001)
|
||||||
|
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||||
|
|
||||||
|
data = data_gen_fn()
|
||||||
|
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
|
||||||
|
output = model(**data)
|
||||||
|
loss = criterion(output)
|
||||||
|
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
with shared_tempdir() as tempdir:
|
||||||
|
|
||||||
|
model_ckpt_path = f"{tempdir}/model"
|
||||||
|
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
new_model = model_huggingface_cls.from_pretrained(model_ckpt_path)
|
||||||
|
new_model = new_model.cuda()
|
||||||
|
new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
|
||||||
|
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
|
||||||
|
|
||||||
|
if plugin_type == 'gemini':
|
||||||
|
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
|
||||||
|
new_model.unwrap().state_dict(only_rank_0=False), False)
|
||||||
|
else:
|
||||||
|
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
config = {}
|
||||||
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
exam_from_pretrained()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [2])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_huggingface_compatibility(world_size):
|
||||||
|
spawn(run_dist, world_size)
|
Loading…
Reference in New Issue
Block a user