[tensor]fix test_linear (#826)

This commit is contained in:
Ziyue Jiang
2022-04-21 17:18:56 +08:00
committed by GitHub
parent 1a9e2c2dff
commit 8e6fdb4f29
2 changed files with 10 additions and 6 deletions

View File

@@ -19,8 +19,9 @@ def colo_linear(types, args, kwargs, pg):
bias = None
else:
bias = kwargs.get('bias', None)
if isinstance(bias, ColoTensor):
bias = bias.torch_tensor()
if isinstance(bias, ColoTensor):
bias = bias.torch_tensor()
# Add communication logic before and after linear call.
if isinstance(weight, ColoTensor):