mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[misc] fit torch api upgrade and remove legacy import (#6093)
* [amp] fit torch's new api * [amp] fix api call * [amp] fix api call * [misc] fit torch pytree api upgrade * [misc] remove legacy import * [misc] fit torch amp api * [misc] fit torch amp api
This commit is contained in:
@@ -1,10 +1,5 @@
|
||||
import torch.nn
|
||||
|
||||
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
|
||||
GradMemStats,
|
||||
GradMemTracerHook,
|
||||
ParamMemTracerHook,
|
||||
)
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import _cast_float
|
||||
|
||||
@@ -27,6 +22,12 @@ class RuntimeMemTracer:
|
||||
|
||||
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
|
||||
super().__init__()
|
||||
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
|
||||
GradMemStats,
|
||||
GradMemTracerHook,
|
||||
ParamMemTracerHook,
|
||||
)
|
||||
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self._gradstat = GradMemStats()
|
||||
|
@@ -8,7 +8,6 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.chunk import Chunk
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
@@ -172,6 +171,8 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
Returns:
|
||||
int: the volume of memory that is evicted
|
||||
"""
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
|
||||
start = time()
|
||||
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
|
||||
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
|
||||
|
Reference in New Issue
Block a user