Use double buffering to handle gradients

This commit is contained in:
ver217
2022-03-15 17:07:35 +08:00
parent 0f5f5dd556
commit 9506a8beb2
4 changed files with 29 additions and 41 deletions

View File

@@ -1,12 +1,14 @@
from typing import Optional
import torch
from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.shard_utils import BaseShardStrategy
from ._base_ophook import BaseOpHook
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Optional
@OPHOOKS.register_module
@@ -62,8 +64,8 @@ class ZeroHook(BaseOpHook):
if param.grad is not None:
if param.col_attr.bwd_count == 0:
# We haven't stored local accumulated grad yet
assert param.col_attr.grad is None
param.col_attr.grad = param.grad.data
assert param.col_attr.fp32_grad is None
param.col_attr.fp32_grad = param.grad.data
param.grad = None
else:
# We have stored local accumulated grad