[shardformer] fix submodule replacement bug when enabling pp (#4544)

This commit is contained in:
Baizhou Zhang
2023-08-31 09:57:18 +08:00
committed by GitHub
parent ec18fc7340
commit 2c787d7f47
5 changed files with 21 additions and 12 deletions

View File

@@ -4,6 +4,7 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -105,6 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()

View File

@@ -4,6 +4,7 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -97,6 +98,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()

View File

@@ -6,6 +6,7 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -107,6 +108,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()