[tensor] fix kwargs in colo_tensor torch_funtion (#825)

This commit is contained in:
Ziyue Jiang
2022-04-21 16:47:35 +08:00
committed by GitHub
parent eb1b89908c
commit 1a9e2c2dff
2 changed files with 3 additions and 4 deletions

View File

@@ -63,6 +63,6 @@ class ColoTensor(object):
kwargs = {}
kwargs = {
kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
}
return func(*args, **kwargs)