Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-01 17:17:05 +00:00
[fx] add torchaudio test (#1369)
* [fx]add torchaudio test
* [fx]add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test
* [fx] add torchaudio test and test patches
* Delete ~
* [fx] add patches and patches test
* [fx] add patches and patches test
* [fx] fix patches
* [fx] fix rnn patches
* [fx] fix rnn patches
* [fx] fix rnn patches
* [fx] fix rnn patches
* [fx] merge upstream
* [fx] fix import errors
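The squashed commits above add an fx tracing test for torchaudio models together with the tracer patches needed to make them traceable. For orientation, here is a minimal sketch of what such a trace-equivalence test typically checks; it uses the stock torch.fx tracer and torchaudio.models.Wav2Letter purely as stand-ins (the commit itself targets ColossalAI's own tracer and patches), and assumes torch and torchaudio are installed.

import torch
import torch.fx
from torchaudio.models import Wav2Letter

def check_trace_equivalence():
    # Illustrative only: the real test exercises ColossalAI's tracer, not symbolic_trace.
    model = Wav2Letter(num_classes=40)
    model.eval()
    gm = torch.fx.symbolic_trace(model)    # build a GraphModule from the eager model
    x = torch.rand(2, 1, 1024)             # (batch, channel, time) waveform input
    with torch.no_grad():
        # the traced graph must reproduce the eager output
        assert torch.allclose(model(x), gm(x))

check_trace_equivalence()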
@@ -27,7 +27,7 @@ def save_checkpoint(dire: str,
     # save the dist context about the tensors in a new dict, while still maintain the original dict.
     for k, v in model_state.items():
         if isinstance(v, ColoTensor):
-            gather_tensor(v)    # gather shared tensors to rank0
+            gather_tensor(v)    # gather shared tensors to rank0
             # don't recover tensors in rank0, since the dict is only a copy of model

     if rank == 0:
@@ -34,7 +34,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
         dist.barrier()

     if dist.get_rank() == 0:
-        setattr(colo_tensor, 'save_ready', True)    # set saving signitrue
+        setattr(colo_tensor, 'save_ready', True)    # set saving signitrue


 def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
@@ -54,9 +54,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
     if dist.get_rank() == 0:
         colo_tensor.set_dist_spec(dist_spec)
     else:
-        rep_tensor = ColoTensor(entire_data, ColoTensorSpec(
-            pg=colo_tensor.get_process_group(),
-            compute_attr=colo_tensor.compute_spec))
+        rep_tensor = ColoTensor(
+            entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec))
         rep_tensor.set_dist_spec(dist_spec)
         with torch.no_grad():
             colo_tensor.data.copy_(rep_tensor.data)
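The hunks above come from the ColoTensor checkpoint utilities: save_checkpoint gathers every ColoTensor in the state dict to rank 0 before writing, gather_tensor marks the gathered tensor with a save_ready flag, and scatter_tensor rebuilds the sharded view from the replicated data afterwards. A rough sketch of how these helpers compose around torch.save follows; the import path for gather_tensor/scatter_tensor, the .dist_spec attribute, and the overall flow are assumptions drawn from the code shown, not a verified copy of the repository's save_checkpoint.

import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor
# hypothetical import path for the helpers shown in the diff
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor

def save_on_rank0(path: str, model: torch.nn.Module) -> None:
    rank = dist.get_rank()
    model_state = model.state_dict()
    saved_specs = {}
    for k, v in model_state.items():
        if isinstance(v, ColoTensor):
            saved_specs[k] = v.dist_spec       # remember the sharding so it can be restored
            gather_tensor(v)                   # replicate the full tensor on rank 0
    if rank == 0:
        torch.save({'model': model_state}, path)
    dist.barrier()                             # wait until the checkpoint file is written
    for k, v in model_state.items():
        if isinstance(v, ColoTensor):
            scatter_tensor(v, saved_specs[k])  # re-shard with the original dist spec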