[hotfix] fix test_stateful_tensor_mgr (#762)

This commit is contained in:
ver217
2022-04-14 15:50:09 +08:00
committed by GitHub
parent 6978980f6d
commit dcca614eee
2 changed files with 20 additions and 13 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Dict
from typing import List, Optional
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
@@ -79,7 +79,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
to_free_tensor_list = [t for (t, idx) in next_compute_idx]
for t in to_free_tensor_list:
if freed_cuda_model_data > to_free_cuda_model_data:
if freed_cuda_model_data >= to_free_cuda_model_data:
break
freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
colo_model_data_tensor_move_inline(t, torch.device('cpu'))