[Tensor ] Add 1Drow weight reshard by spec (#854)

This commit is contained in:
Ziyue Jiang
2022-04-24 18:30:20 +08:00
committed by GitHub
parent d7e0303d1e
commit bcc8655021
5 changed files with 41 additions and 11 deletions

View File

@@ -1,10 +1,11 @@
from zmq import device
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.nn import CheckpointModule
from .utils.dummy_data_generator import DummyDataGenerator
from .registry import non_distributed_component_funcs
from colossalai.utils.cuda import get_current_device
class SimpleNet(CheckpointModule):
"""
@@ -25,8 +26,8 @@ class SimpleNet(CheckpointModule):
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 4)
label = torch.randint(low=0, high=2, size=(16,))
data = torch.rand(16, 4, device=get_current_device())
label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
return data, label