[format] applied code formatting on changed files in pull request 5088 (#5127)

Co-authored-by: github-actions <github-actions@github.com>
This commit is contained in:
github-actions[bot]
2023-11-29 13:38:37 +08:00
committed by GitHub
parent 9110406a47
commit d10ee42f68
2 changed files with 23 additions and 10 deletions

View File

@@ -1,7 +1,7 @@
from contextlib import nullcontext
from typing import Optional
import pytest
import pytest
import torch
import torch.distributed as dist
@@ -11,8 +11,6 @@ from colossalai.booster.plugin import GeminiPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -26,7 +24,13 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
ctx = nullcontext()
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
enable_all_optimization = True if tp_size > 1 else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
plugin = GeminiPlugin(
max_norm=1.0,
initial_scale=2**5,
tp_size=tp_size,
extra_dp_size=extra_dp_size,
enable_all_optimization=enable_all_optimization,
)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
@@ -66,7 +70,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
@parameterize("init_method", ["none"])
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
def check_gemini_plugin(
subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1
):
"""check gemini plugin over model zoo
Args:
@@ -161,6 +167,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
@@ -168,4 +175,4 @@ def test_gemini_plugin_3d(early_stop: bool = True):
if __name__ == "__main__":
test_gemini_plugin(early_stop=False)
test_gemini_plugin(early_stop=False)