[polish] polish singleton and global context (#500)

This commit is contained in:
Jiarui Fang
2022-03-23 18:03:39 +08:00
committed by GitHub
parent 9ec1ce6ab1
commit a445e118cf
18 changed files with 39 additions and 47 deletions

View File

@@ -1,12 +1,15 @@
import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
import colossalai
from colossalai.testing import assert_close_loose
from colossalai.utils import free_port
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
import copy
import pytest
from functools import partial

View File

@@ -7,7 +7,7 @@ import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.testing import assert_equal_in_group

View File

@@ -8,7 +8,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
BATCH_SIZE = 16
NUM_EXPERTS = 4

View File

@@ -6,7 +6,7 @@ import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Experts
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.testing import assert_equal_in_group