[shardformer] fix submodule replacement bug when enabling pp (#4544)

This commit is contained in:
Baizhou Zhang
2023-08-31 09:57:18 +08:00
committed by GitHub
parent ec18fc7340
commit 2c787d7f47
5 changed files with 21 additions and 12 deletions

View File

@@ -7,6 +7,7 @@ from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
check_state_dict_equal,
@@ -100,6 +101,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
Randomizer.reset_index()
clear_layout_converter()