[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)

* [test] refactor torch ddp checkpoint test

* [plugin] update low level zero optim checkpoint

* [plugin] update gemini optim checkpoint
Hongxin Liu 2023-05-19 19:42:31 +08:00 committed by GitHub
parent 60e6a154bc
commit 3c07a2846e
6 changed files with 128 additions and 82 deletions
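
The heart of the workaround in the diffs below is a per-rank naming scheme: each process writes its own optimizer shard by appending its rank to the checkpoint path, and a checkpoint saved this way can only be loaded by a run with the same number of processes. A minimal standalone sketch of that scheme (the helper names here are illustrative, not part of the plugins):

import torch
import torch.distributed as dist


def save_optimizer_per_rank(optimizer: torch.optim.Optimizer, path: str) -> None:
    # every rank writes its own shard, mirroring the f'{checkpoint}.rank{rank}' naming below
    torch.save(optimizer.state_dict(), f'{path}.rank{dist.get_rank()}')


def load_optimizer_per_rank(optimizer: torch.optim.Optimizer, path: str) -> None:
    # only valid when the current world size matches the one that saved the files
    optimizer.load_state_dict(torch.load(f'{path}.rank{dist.get_rank()}'))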


@@ -52,8 +52,16 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Save optimizer to checkpoint but only on master process.
         """
         # TODO(ver217): optimizer state dict is sharded
+        warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
+        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
         super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
 
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        warnings.warn(
+            'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
+        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
+        super().load_optimizer(optimizer, checkpoint)
+
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
         Save model to checkpoint but only on master process.
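
Seen from the Booster API that the tests below exercise, this change means booster.save_optimizer under GeminiPlugin now writes one file per process (path.rank0, path.rank1, ...) and warns instead of producing a single full checkpoint, and loading requires the same world size. A sketch of the round trip, assuming it runs inside an already launched colossalai process group (the tiny Linear model and the checkpoint path are illustrative):

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

booster = Booster(plugin=GeminiPlugin(placement_policy='cuda'))
model = torch.nn.Linear(8, 8)
criterion = lambda x: x.mean()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# each rank writes 'optim_ckpt.rank{rank}'; there is no single full optimizer checkpoint yet
booster.save_optimizer(optimizer, 'optim_ckpt')
# must be loaded by a job with the same number of processes that saved it
booster.load_optimizer(optimizer, 'optim_ckpt')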


@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
-from colossalai.checkpoint_io import CheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
 from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
@@ -32,8 +32,17 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         """
         Save optimizer to checkpoint but only on master process.
         """
-        # TODO(ver217): optimizer state dict is sharded
-        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+        # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
+        warnings.warn(
+            'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
+        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
+        GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
+
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        warnings.warn(
+            'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
+        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
+        super().load_optimizer(optimizer, checkpoint)
 
 
 class LowLevelZeroModel(ModelWrapper):
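
One detail worth noting: GeneralCheckpointIO.save_unsharded_optimizer(self, ...) is called explicitly rather than through super() because LowLevelZeroCheckpointIO inherits from TorchDDPCheckpointIO, whose optimizer save is master-only, while this workaround needs every rank to write its own shard. A small sketch of that explicit-base-call pattern (the class names below are illustrative, not ColossalAI's):

class BaseIO:
    def save(self, path: str) -> None:
        print(f'this process writes {path}')


class MasterOnlyIO(BaseIO):
    def save(self, path: str) -> None:
        # imagine only rank 0 actually writes in this subclass
        print('only the master rank writes')


class PerRankIO(MasterOnlyIO):
    def save(self, path: str) -> None:
        # skip MasterOnlyIO.save and call the grandparent directly,
        # so the write happens on every process
        BaseIO.save(self, path)


PerRankIO().save('optimizer.rank0')    # prints: this process writes optimizer.rank0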


@@ -1,87 +1,95 @@
-import tempfile
+import os
 
 import pytest
 import torch
+import torch.distributed as dist
+from utils import shared_tempdir
 
 import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin
 from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
+from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ColoInitContext, ZeroDDP
+from colossalai.zero import ZeroDDP
 from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo
 
 
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', ['bert'])
-@parameterize('use_safetensors', [True, False])
+@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
+@parameterize('use_safetensors', [False, True])
 def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
     from transformers import BertForSequenceClassification
-    model_ckpt_dir = tempfile.TemporaryDirectory()
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, *_ = get_components_func()
-    with ColoInitContext(device=(get_current_device())):
-        bert_model = model_builder()
-    bert_model.config.save_pretrained(save_directory=(model_ckpt_dir.name))
-    config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
-    chunk_manager = ChunkManager(config_dict)
-    gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    bert_model = ZeroDDP(bert_model, gemini_manager)
-    bert_model.train()
-    ckpt_io = GeminiCheckpointIO()
-    if ckpt_io.coordinator.is_master():
-        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
-        ckpt_io.save_model(bert_model, (model_ckpt_dir.name),
-                           True,
-                           True,
-                           '', (model_size / 3),
-                           use_safetensors=use_safetensors)
-        new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
-        check_state_dict_equal(bert_model.state_dict(only_rank_0=True, dtype=(torch.float32)),
-                               new_bert_model.state_dict(), False)
-    model_ckpt_dir.cleanup()
+    (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    bert_model = model_fn()
+    with shared_tempdir() as tempdir:
+        pretrained_path = os.path.join(tempdir, 'pretrained')
+        bert_model.config.save_pretrained(save_directory=pretrained_path)
+        # TODO(ver217): use boost api
+        config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
+        chunk_manager = ChunkManager(config_dict)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        bert_model = ZeroDDP(bert_model, gemini_manager)
+        bert_model.train()
+        ckpt_io = GeminiCheckpointIO()
+        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
+        ckpt_io.save_model(bert_model, (pretrained_path),
+                           True,
+                           True,
+                           '', (model_size / 3),
+                           use_safetensors=use_safetensors)
+        dist.barrier()
+        new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
+        check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
+                               new_bert_model.state_dict(), False)
 
 
 @parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', ['gpt2', 'bert'])
-@parameterize('use_safetensors', [True, False])
-def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, *_ = get_components_func()
-    with ColoInitContext(device=(get_current_device())):
-        model = model_builder()
-        new_model = model_builder()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
-    chunk_manager = ChunkManager(config_dict)
-    gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager)
-    model.train()
-    #new model
-    new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
-    new_chunk_manager = ChunkManager(new_config_dict)
-    new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
-    new_model = ZeroDDP(new_model, new_gemini_manager)
-    model_ckpt_dir = tempfile.TemporaryDirectory()
-    ckpt_io = GeminiCheckpointIO()
-    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
-    ckpt_io.save_model(model, (model_ckpt_dir.name),
-                       True,
-                       True,
-                       'epoch', (model_size / 3),
-                       use_safetensors=use_safetensors)
-    if ckpt_io.coordinator.is_master():
-        ckpt_io.load_model(new_model, (model_ckpt_dir.name), strict=True)
-        model_dict = model.state_dict(only_rank_0=True)
-        new_model_dict = new_model.state_dict(only_rank_0=True)
-        check_state_dict_equal(model_dict, new_model_dict, False)
-    model_ckpt_dir.cleanup()
+@parameterize('shard', [True, False])
+@parameterize('model_name', ['transformers_gpt'])
+def exam_state_dict(placement_policy, shard: bool, model_name: str):
+    (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    criterion = lambda x: x.mean()
+    plugin = GeminiPlugin(placement_policy=placement_policy)
+    booster = Booster(plugin=plugin)
+    model = model_fn()
+    new_model = model_fn()
+    optimizer = HybridAdam(model.parameters(), lr=0.001)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+    new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+    new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
+    data = data_gen_fn()
+    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
+    output = model(**data)
+    output = output_transform_fn(output)
+    output_key = list(output.keys())[0]
+    loss = criterion(output[output_key])
+    booster.backward(loss, optimizer)
+    optimizer.step()
+    with shared_tempdir() as tempdir:
+        model_ckpt_path = f"{tempdir}/model"
+        optimizer_ckpt_path = f"{tempdir}/optimizer"
+        booster.save_model(model, model_ckpt_path)
+        if not shard:
+            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
+            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+        dist.barrier()
+        booster.load_model(new_model, model_ckpt_path)
+        check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
+                               new_model.unwrap().state_dict(only_rank_0=False), False)
+        if not shard:
+            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
 
 
 def run_dist(rank, world_size, port):
@@ -92,7 +100,7 @@ def run_dist(rank, world_size, port):
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize('world_size', [4, 4])
+@pytest.mark.parametrize('world_size', [2])
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)
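
The body of run_dist falls outside both hunks shown above; in these tests it typically just initializes the process group and then invokes the parameterized checks. A hedged sketch of that wiring, assuming the exam functions defined above:

def run_dist(rank, world_size, port):
    # each spawned process joins the same process group, then runs the checks above
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    exam_state_dict_with_origin()
    exam_state_dict()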


@@ -1,13 +1,11 @@
-import tempfile
-
-import pytest
 import torch
+import torch.distributed as dist
 from torchvision.models import resnet18
+from utils import shared_tempdir
 
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroCheckpointIO
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import (
     check_state_dict_equal,
@@ -20,7 +18,8 @@ from colossalai.testing import (
 
 @clear_cache_before_run()
 @parameterize('stage', [2])
-def check_low_level_zero_checkpointIO(stage: int):
+@parameterize('shard', [True, False])
+def check_low_level_zero_checkpointIO(stage: int, shard: bool):
     plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
     booster = Booster(plugin=plugin)
     model = resnet18()
@@ -34,17 +33,25 @@ def check_low_level_zero_checkpointIO(stage: int):
     loss = criterion(output)
     booster.backward(loss, optimizer)
     optimizer.step()
 
-    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
-    ckpt_io = LowLevelZeroCheckpointIO()
-    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
+    with shared_tempdir() as tempdir:
+        model_ckpt_path = f"{tempdir}/model"
+        optimizer_ckpt_path = f"{tempdir}/optimizer"
+        # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
+        booster.save_model(model, model_ckpt_path, shard=shard)
+        if not shard:
+            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
+            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+        dist.barrier()
 
-    new_model = resnet18()
-    new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
-    _, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
-
-    if ckpt_io.coordinator.is_master():
-        ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
-        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+        new_model = resnet18()
+        new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
+        new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
+        booster.load_model(new_model, model_ckpt_path)
+        check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+        if not shard:
+            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
 
 
 def run_dist(rank, world_size, port):


@@ -1,10 +1,9 @@
-import tempfile
-
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 from torchvision.models import resnet18
+from utils import shared_tempdir
 
 import colossalai
 from colossalai.booster import Booster
@@ -35,11 +34,7 @@ def check_torch_ddp_checkpointIO(shard: bool):
     optimizer.step()
     scheduler.step()
 
-    with tempfile.TemporaryDirectory() as tempdir:
-        obj = [tempdir]
-        dist.broadcast_object_list(obj, src=0)
-        tempdir = obj[0]    # use the same directory on all ranks
-
+    with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
@@ -66,8 +61,6 @@ def check_torch_ddp_checkpointIO(shard: bool):
         booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
         check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
 
-    dist.barrier()
-
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')


@@ -0,0 +1,21 @@
+import tempfile
+from contextlib import contextmanager, nullcontext
+from typing import Iterator
+
+import torch.distributed as dist
+
+
+@contextmanager
+def shared_tempdir() -> Iterator[str]:
+    """
+    A temporary directory that is shared across all processes.
+    """
+    ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
+    with ctx_fn() as tempdir:
+        try:
+            obj = [tempdir]
+            dist.broadcast_object_list(obj, src=0)
+            tempdir = obj[0]    # use the same directory on all ranks
+            yield tempdir
+        finally:
+            dist.barrier()
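
This helper is what lets every rank in the tests above write into, and read from, a single directory: rank 0 creates the real temporary directory, broadcasts its path to the other ranks, and the barrier in the finally block keeps rank 0 from cleaning it up before everyone is done. A usage sketch, assuming the process group is already initialized (the tensor round trip is illustrative):

import os

import torch
import torch.distributed as dist
from utils import shared_tempdir


def roundtrip_tensor(t: torch.Tensor) -> torch.Tensor:
    with shared_tempdir() as tempdir:
        path = os.path.join(tempdir, f'tensor.rank{dist.get_rank()}')
        torch.save(t, path)    # every rank writes into the same shared directory
        dist.barrier()         # make sure all files exist before any rank reads
        return torch.load(path)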