[fx] add torchaudio test (#1369)

* [fx]add torchaudio test

* [fx]add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test

* [fx] add torchaudio test and test patches

* Delete ~

* [fx] add patches and patches test

* [fx] add patches and patches test

* [fx] fix patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] fix rnn patches

* [fx] merge upstream

* [fx] fix import errors
This commit is contained in:
Super Daniel
2022-07-27 11:03:14 +08:00
committed by GitHub
parent fb6f085907
commit be229217ce
18 changed files with 609 additions and 16 deletions

View File

@@ -27,7 +27,7 @@ def save_checkpoint(dire: str,
# save the dist context about the tensors in a new dict, while still maintain the original dict.
for k, v in model_state.items():
if isinstance(v, ColoTensor):
gather_tensor(v) # gather shared tensors to rank0
gather_tensor(v) # gather shared tensors to rank0
# don't recover tensors in rank0, since the dict is only a copy of model
if rank == 0:

View File

@@ -34,7 +34,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
dist.barrier()
if dist.get_rank() == 0:
setattr(colo_tensor, 'save_ready', True) # set saving signitrue
setattr(colo_tensor, 'save_ready', True) # set saving signitrue
def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
@@ -54,9 +54,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
if dist.get_rank() == 0:
colo_tensor.set_dist_spec(dist_spec)
else:
rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
pg=colo_tensor.get_process_group(),
compute_attr=colo_tensor.compute_spec))
rep_tensor = ColoTensor(
entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
rep_tensor.set_dist_spec(dist_spec)
with torch.no_grad():
colo_tensor.data.copy_(rep_tensor.data)