Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 01:06:00 +00:00.
[tensor] customized op returns ColoTensor (#875)

* [tensor] customized op returns ColoTensor
* polish
* polish code
This commit is contained in:
@@ -1,19 +1,13 @@
|
||||
from cProfile import label
|
||||
from statistics import mode
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import ColoInitContext
|
||||
|
||||
import torch.distributed as dist
|
||||
from functools import partial
|
||||
|
||||
|
||||
|
@@ -53,11 +53,11 @@ def test_linear():
|
||||
|
||||
# torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
|
||||
out = fc(input_tensor)
|
||||
loss = out.sum()
|
||||
loss = torch.sum(out)
|
||||
loss.backward()
|
||||
|
||||
out_ref = fc_ref(input_ref)
|
||||
loss_ref = out_ref.sum()
|
||||
loss_ref = torch.sum(out_ref)
|
||||
loss_ref.backward()
|
||||
|
||||
assert (loss_ref == loss)
|
||||
|
Reference in New Issue
Block a user