diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py index c85f53cc4..5234b6b1a 100644 --- a/colossalai/nn/layer/base_layer.py +++ b/colossalai/nn/layer/base_layer.py @@ -1,11 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from contextlib import contextmanager + import torch.nn as nn from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from contextlib import contextmanager class ParallelLayer(nn.Module): diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 3be057b3a..789ce8ab3 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -1,6 +1,7 @@ -from colossalai.device.device_mesh import DeviceMesh import torch +from colossalai.device.device_mesh import DeviceMesh + def test_device_mesh(): physical_mesh_id = torch.arange(0, 16).reshape(2, 8)