[tensor] redirect .data.__get__ to a tensor instance (#1239)

This commit is contained in:
HELSON
2022-07-11 11:41:29 +08:00
committed by GitHub
parent 20da6e48c8
commit f6add9b720
2 changed files with 20 additions and 12 deletions

View File

@@ -86,6 +86,7 @@ def _run_tensor_shard_init(world_size):
tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_dist_spec(distspec.replicate())
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"