[NFC] polish code format

binmakeswell 2023-02-15 23:21:36 +08:00 committed by GitHub
commit 30aee9c45d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 221 additions and 214 deletions
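The hunks below only reorder import statements; there is no functional change. The new ordering appears to follow isort-style conventions (an assumption, the commit does not name the tool): standard-library imports first, then third-party packages, then colossalai modules, then relative imports, each group alphabetized and '..'-level imports ahead of '.'-level ones. A minimal sketch of that layout, built from names that appear in the diffs below:

# Sketch of the target import layout (isort-style grouping is an assumption);
# the module names are taken from the hunks in this commit.
import math                                                 # standard library

import torch.distributed as dist                            # third-party

from colossalai.global_variables import tensor_parallel_env as env   # first-party, alphabetized
from colossalai.registry import DIST_GROUP_INITIALIZER

from ..parallel_mode import ParallelMode                    # relative, '..' before '.'
from .process_group_initializer import ProcessGroupInitializer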

View File

@@ -1,7 +1,8 @@
 import click
-from .launcher import run
-from .check import check
 from .benchmark import benchmark
+from .check import check
+from .launcher import run
 class Arguments():

View File

@@ -1,7 +1,9 @@
 import click
-from .run import launch_multi_processes
 from colossalai.context import Config
+from .run import launch_multi_processes
 @click.command(help="Launch distributed training on a single node or multiple nodes",
                context_settings=dict(ignore_unknown_options=True))

View File

@@ -1,3 +1,5 @@
+from typing import Tuple
 import torch
 import torch.distributed as dist
@@ -5,8 +7,6 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor import ProcessGroup
-from typing import Tuple
 def _check_sanity():
     from colossalai.core import global_context as gpc

View File

@@ -2,10 +2,11 @@ import math
 import torch.distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
-from ..parallel_mode import ParallelMode
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.registry import DIST_GROUP_INITIALIZER
+from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
 def _check_summa_env_var(summa_dim):
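The hunk above touches the 2D (SUMMA) process-group initializer; _check_summa_env_var itself is unchanged and its body is not part of this diff. As a hypothetical sketch only (not ColossalAI's actual code, and assuming tensor_parallel_env exposes a summa_dim attribute), such a check typically verifies that a previously recorded grid dimension agrees with the one computed from the tensor-parallel size:

# Hypothetical sketch of a SUMMA-dimension sanity check; not the real implementation.
from colossalai.global_variables import tensor_parallel_env as env

def _check_summa_env_var(summa_dim):
    # 2D (SUMMA) tensor parallelism arranges ranks on a square grid, so a
    # previously recorded dimension must agree with the one computed here.
    recorded = env.summa_dim    # assumed attribute on tensor_parallel_env
    if recorded:
        assert int(recorded) == summa_dim, \
            'SUMMA dimension recorded in the environment does not match the computed value'
    else:
        env.summa_dim = summa_dim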

View File

@@ -4,8 +4,9 @@
 from torch import distributed as dist
 from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
 @DIST_GROUP_INITIALIZER.register_module

View File

@@ -3,9 +3,10 @@
 import torch.distributed as dist
 from colossalai.registry import DIST_GROUP_INITIALIZER
+from ..parallel_mode import ParallelMode
 from .initializer_tensor import Initializer_Tensor
 from .process_group_initializer import ProcessGroupInitializer
-from ..parallel_mode import ParallelMode
 @DIST_GROUP_INITIALIZER.register_module

View File

@@ -1,7 +1,8 @@
+from typing import Iterable
 import torch.distributed as dist
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from typing import Iterable
 def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
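The imports in this last file (_flatten_dense_tensors and _unflatten_dense_tensors together with torch.distributed) hint at what bucket_allreduce does: flatten a bucket of gradients into one contiguous buffer, issue a single all-reduce, and scatter the result back. A minimal sketch under that assumption, not the ColossalAI implementation itself:

from typing import Iterable

import torch.distributed as dist
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
    # Gather gradients that actually exist; skip parameters without grads.
    grads = [p.grad.data for p in param_list if p.grad is not None]
    if not grads:
        return
    # Flatten into one contiguous buffer so a single collective call suffices.
    flat = _flatten_dense_tensors(grads)
    dist.all_reduce(flat, group=group)
    flat.div_(dist.get_world_size(group=group))
    # Copy the averaged values back into the original gradient tensors.
    for grad, reduced in zip(grads, _unflatten_dense_tensors(flat, grads)):
        grad.copy_(reduced)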