mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-05 12:24:38 +00:00
[booster] implemented the torch ddp + resnet example (#3232)
* [booster] implemented the torch ddp + resnet example * polish code
This commit is contained in:
@@ -8,8 +8,8 @@ from torch.optim import SGD
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.interface import OptimizerWrapper
|
||||
from colossalai.booster.plugin import TorchDDPPlugin
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
@@ -34,7 +34,7 @@ def check_torch_ddp_plugin():
|
||||
|
||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||
|
||||
assert isinstance(model, DDP)
|
||||
assert isinstance(model.module, DDP)
|
||||
assert isinstance(optimizer, OptimizerWrapper)
|
||||
|
||||
output = model(**data)
|
||||
|
||||
@@ -42,8 +42,8 @@ def test_unsharded_checkpoint():
|
||||
new_optimizer = Adam(new_model.parameters(), lr=0.001)
|
||||
|
||||
# load the model and optimizer
|
||||
new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
|
||||
new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
|
||||
ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
|
||||
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
|
||||
|
||||
# do recursive check for the optimizer state dict
|
||||
# if the value is a dict, compare its values
|
||||
|
||||
Reference in New Issue
Block a user