[misc] fit torch api upgrade and remove legacy import (#6093)

* [amp] fit torch's new api

* [amp] fix api call

* [amp] fix api call

* [misc] fit torch pytree api upgrade

* [misc] remove legacy import

* [misc] fit torch amp api

* [misc] fit torch amp api
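
The hunks below only cover the lazy-import part of the change. As a rough, hedged illustration of the torch AMP API shift the commit message refers to (illustrative training-loop names, not code from this PR; assumes a CUDA device and a recent torch where the torch.cuda.amp entry points are deprecated), the device-specific API gives way to the device-agnostic torch.amp one:

import torch

# Hypothetical toy model/optimizer, only to make the snippet self-contained.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
inputs = torch.randn(4, 8, device="cuda")

# Old, CUDA-specific API (deprecated on recent torch):
#   scaler = torch.cuda.amp.GradScaler()
#   with torch.cuda.amp.autocast(dtype=torch.float16): ...

# New device-agnostic API: the device type is passed explicitly.
scaler = torch.amp.GradScaler("cuda")
with torch.amp.autocast("cuda", dtype=torch.float16):
    loss = model(inputs).sum()

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()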
Author: Hongxin Liu
Date: 2024-10-18 16:48:52 +08:00
Committed by: GitHub
Parent: 5ddad486ca
Commit: 58d8b8a2dd
7 changed files with 20 additions and 12 deletions


@@ -1,10 +1,5 @@
 import torch.nn
-from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
-    GradMemStats,
-    GradMemTracerHook,
-    ParamMemTracerHook,
-)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float
@@ -27,6 +22,12 @@ class RuntimeMemTracer:
     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
+        from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
+            GradMemStats,
+            GradMemTracerHook,
+            ParamMemTracerHook,
+        )
         self.module = module
         self.dtype = dtype
         self._gradstat = GradMemStats()
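
The hunk above moves the colossalai.legacy import from module scope into RuntimeMemTracer.__init__. A minimal sketch of that deferred-import pattern (made-up names; a stdlib module stands in for the legacy package) shows why it helps: merely importing the module no longer pulls in the legacy code, so its import cost and any deprecation warnings only surface when an instance is actually constructed.

import torch.nn


class LazyTracer:
    """Sketch of the deferred-import pattern; not the PR's actual class."""

    def __init__(self, module: torch.nn.Module):
        # Importing here rather than at module scope means importing this
        # file never loads the legacy dependency; it is only loaded when a
        # LazyTracer is constructed.
        import decimal as legacy_pkg  # stdlib stand-in for a legacy/heavy package

        self.module = module
        self.legacy = legacy_pkg


tracer = LazyTracer(torch.nn.Linear(2, 2))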


@@ -8,7 +8,6 @@ import torch
 import torch.distributed as dist
 from colossalai.accelerator import get_accelerator
-from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk
 from .chunk import Chunk, ChunkManager
@@ -172,6 +171,8 @@ class AutoPlacementPolicy(PlacementPolicy):
         Returns:
             int: the volume of memory that is evicted
         """
+        from colossalai.legacy.utils.memory import colo_device_memory_capacity
+
         start = time()
         cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
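
The pytree-related change mentioned in the commit message is not among the hunks shown here. As a hedged sketch of that API upgrade, assuming a recent torch where torch.utils._pytree exposes register_pytree_node in place of the deprecated _register_pytree_node (Pair and the flatten/unflatten helpers below are made-up names, not code from this PR):

from dataclasses import dataclass
from typing import Any

import torch
import torch.utils._pytree as pytree


@dataclass
class Pair:  # hypothetical container type
    left: Any
    right: Any


def _flatten(p: Pair):
    # Return the child values plus any static context needed to rebuild.
    return [p.left, p.right], None


def _unflatten(values, context) -> Pair:
    return Pair(*values)


# Old (now deprecated) private helper:
#   pytree._register_pytree_node(Pair, _flatten, _unflatten)
# Newer torch versions expose the registration function without the underscore:
pytree.register_pytree_node(Pair, _flatten, _unflatten)

leaves, spec = pytree.tree_flatten(Pair(torch.ones(1), torch.zeros(1)))
rebuilt = pytree.tree_unflatten(leaves, spec)
print(type(rebuilt), leaves)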