[booster] implemented the torch ddd + resnet example (#3232)

* [booster] implemented the torch ddd + resnet example

* polish code
This commit is contained in:
Frank Lee
2023-03-27 10:24:14 +08:00
committed by GitHub
parent 1a229045af
commit 73d3e4d309
22 changed files with 608 additions and 128 deletions

View File

@@ -8,8 +8,8 @@ from torch.optim import SGD
import colossalai
from colossalai.booster import Booster
from colossalai.booster.interface import OptimizerWrapper
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo
@@ -34,7 +34,7 @@ def check_torch_ddp_plugin():
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
assert isinstance(model, DDP)
assert isinstance(model.module, DDP)
assert isinstance(optimizer, OptimizerWrapper)
output = model(**data)

View File

@@ -42,8 +42,8 @@ def test_unsharded_checkpoint():
new_optimizer = Adam(new_model.parameters(), lr=0.001)
# load the model and optimizer
new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
# do recursive check for the optimizer state dict
# if the value is a dict, compare its values