https://github.com/hpcaitech/ColossalAI.git
free param.grad

parent 9506a8beb2
commit ea6905a898
@@ -1,5 +1,5 @@
 import functools
-from ast import Try
+from asyncio.log import logger
 from collections import OrderedDict
 from typing import Any, Optional
 
@@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                            get_gradient_predivide_factor)
 
 
@@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
             else:
                 self._reduce_scatter_callback(param, new_grad)
             orig_grad_data.record_stream(self.comm_stream)
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        return empty_grad
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         if self.gradient_postdivide_factor > 1:
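
For context on why returning empty_grad frees memory: free_storage releases the buffer behind a tensor while keeping the tensor object and its shape/dtype metadata alive, and a PyTorch gradient hook may return a replacement gradient. Handing back a freed, same-shaped tensor therefore stops param.grad from pinning a full-size buffer once the reduce-scatter callback has consumed the real gradient. The sketch below illustrates the idea; the free_storage body and the toy usage shown here are assumptions for illustration, not code taken from this commit.

import torch


def free_storage(data: torch.Tensor) -> None:
    # Assumed implementation of the helper imported from ._zero3_utils:
    # shrink the tensor's underlying storage to zero elements so the memory
    # is returned to the allocator while the tensor object stays alive.
    if data.storage().size() > 0:
        # Only safe when the tensor starts at offset 0 and no other view
        # still needs this storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)


# Toy usage mirroring the hook above: build a same-shaped "empty" gradient,
# drop its storage, and return it so param.grad stops holding real memory.
grad = torch.randn(1024, 1024)
empty_grad = torch.empty_like(grad)   # allocates a buffer the size of grad
free_storage(empty_grad)              # releases that buffer immediately
assert empty_grad.storage().size() == 0 and empty_grad.shape == grad.shape

Assuming the surrounding method is registered as a gradient hook on each parameter, the empty_grad returned above is what autograd ends up storing into param.grad, which is how the commit message "free param.grad" is realized.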