mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[hotfix] fix test_stateful_tensor_mgr (#762)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Dict
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
@@ -79,7 +79,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
to_free_tensor_list = [t for (t, idx) in next_compute_idx]
|
||||
for t in to_free_tensor_list:
|
||||
if freed_cuda_model_data > to_free_cuda_model_data:
|
||||
if freed_cuda_model_data >= to_free_cuda_model_data:
|
||||
break
|
||||
freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
|
||||
colo_model_data_tensor_move_inline(t, torch.device('cpu'))
|
||||
|
Reference in New Issue
Block a user