mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-22 09:59:38 +00:00)
use double buffer to handle grad
@@ -1,12 +1,14 @@
+from typing import Optional
+
 import torch
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 
 from ._base_ophook import BaseOpHook
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from typing import Optional
 
 
 @OPHOOKS.register_module
@@ -62,8 +64,8 @@ class ZeroHook(BaseOpHook):
         if param.grad is not None:
             if param.col_attr.bwd_count == 0:
                 # We haven't stored local accumulated grad yet
-                assert param.col_attr.grad is None
-                param.col_attr.grad = param.grad.data
+                assert param.col_attr.fp32_grad is None
+                param.col_attr.fp32_grad = param.grad.data
                 param.grad = None
             else:
                 # We have stored local accumulated grad
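The rename from col_attr.grad to col_attr.fp32_grad reflects the double-buffer scheme named in the commit title: the transient gradient produced by backward (param.grad) and a persistent fp32 accumulation buffer are kept as two separate buffers, so param.grad can be released as soon as it has been folded into the fp32 copy. Below is a minimal sketch of that idea, assuming a hypothetical ColAttr container and handle_grad helper (neither is the actual ColossalAI API; the real hook lives in ZeroHook and handles sharding as well):

import torch


class ColAttr:
    """Hypothetical stand-in for the per-parameter attribute object."""

    def __init__(self):
        self.bwd_count = 0     # number of backward passes seen so far
        self.fp32_grad = None  # persistent fp32 accumulation buffer


def handle_grad(param: torch.nn.Parameter, col_attr: ColAttr) -> None:
    """Fold the freshly computed param.grad into the fp32 buffer, then free it."""
    if param.grad is None:
        return
    if col_attr.bwd_count == 0:
        # First backward pass: adopt the grad as the fp32 buffer.
        assert col_attr.fp32_grad is None
        col_attr.fp32_grad = param.grad.data.float()
    else:
        # Later passes: accumulate into the existing fp32 buffer.
        col_attr.fp32_grad.add_(param.grad.data.float())
    # Release the transient buffer so its memory can be reused.
    param.grad = None
    col_attr.bwd_count += 1


# Usage: two backward "steps" on an fp16 parameter.
p = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
attr = ColAttr()
for _ in range(2):
    p.grad = torch.ones(4, dtype=torch.float16)
    handle_grad(p, attr)
print(attr.fp32_grad)  # tensor([2., 2., 2., 2.])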