[zero] zero init context (#321)

* add zero init context

* add more flags for zero init context
fix a bug where a param could be repeatedly converted to ShardedParamV2 (see the sketch below)

* polish code
Author: Jiarui Fang
Date: 2022-03-07 16:14:40 +08:00
Committed by: Frank Lee
Parent: 73bff11288
Commit: de0468c7a8
4 changed files with 173 additions and 5 deletions
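
The init context itself lives in the new files of this commit and is not shown in this excerpt. As a hedged sketch of the idea, and of why the guard against repeated conversion matters: a zero init context typically patches the __init__ of every nn.Module subclass so that parameters are converted to sharded form the moment a module is constructed. Every name below (ZeroInitContext, _all_subclasses, ca_attr) is a hypothetical stand-in, not the actual colossalai API; ShardedParamV2 itself is stubbed out.

import torch.nn as nn


def _all_subclasses(cls):
    # Recursively collect every subclass of cls, depth-first.
    for sub in cls.__subclasses__():
        yield sub
        yield from _all_subclasses(sub)


class ZeroInitContext:
    """Hypothetical sketch: convert the parameters of every module
    constructed inside the context, exactly once per parameter."""

    def _post_init(self, module):
        for param in module.parameters(recurse=False):
            # Idempotency guard: when a subclass __init__ chains to a
            # wrapped parent __init__, _post_init fires once per level.
            # Skipping already-tagged params avoids converting twice --
            # the "repeated converting" bug named in the commit message.
            if hasattr(param, 'ca_attr'):
                continue
            param.ca_attr = 'sharded'  # stand-in for ShardedParamV2(param)

    def __enter__(self):
        self._saved = []
        for cls in set(_all_subclasses(nn.Module)):
            if '__init__' not in cls.__dict__:
                continue  # inherited __init__ is covered by the parent's patch
            old_init = cls.__dict__['__init__']
            self._saved.append((cls, old_init))

            def wrapped(module, *args, _old=old_init, **kwargs):
                _old(module, *args, **kwargs)
                self._post_init(module)

            cls.__init__ = wrapped
        return self

    def __exit__(self, *exc):
        # Restore the original constructors on exit.
        for cls, old_init in self._saved:
            cls.__init__ = old_init
        return False


with ZeroInitContext():
    layer = nn.Linear(4, 4)
assert hasattr(layer.weight, 'ca_attr') and hasattr(layer.bias, 'ca_attr')

Because a subclass __init__ chains into its wrapped parent __init__, _post_init can run more than once per module; the hasattr guard makes conversion idempotent, which is plausibly the shape of the bug the second bullet fixes.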

colossalai/zero/sharded_param/sharded_param.py

@@ -1,5 +1,3 @@
-from typing import Optional, Tuple, Union
-import numpy
 import torch
 import torch.distributed as dist
@@ -8,8 +6,6 @@ from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param import ShardedTensor
-from typing import Union, Tuple, Optional
-import numpy
 
 
 class ShardedParamV2(object):
@@ -35,7 +31,10 @@ class ShardedParamV2(object):
     @property
     def grad(self):
-        return self._grad_sharded_tensor.payload
+        if self._grad_sharded_tensor:
+            return self._grad_sharded_tensor.payload
+        else:
+            return None
 
     @grad.setter
     def grad(self, t: torch.Tensor):
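
The guard above mirrors torch.nn.Parameter.grad semantics: before any backward pass there is no gradient, so the property must return None rather than dereference .payload on a None holder (the old single-line body raised AttributeError in that case). Below is a minimal self-contained illustration of the pattern; ShardedTensorStub and ParamStub are hypothetical stand-ins, not colossalai's real ShardedTensor or ShardedParamV2.

import torch


class ShardedTensorStub:
    # Stand-in for ShardedTensor: just a holder exposing .payload.
    def __init__(self, payload: torch.Tensor):
        self.payload = payload


class ParamStub:
    def __init__(self):
        # No gradient has been produced yet.
        self._grad_sharded_tensor = None

    @property
    def grad(self):
        # Same guard as the hunk above: return None while no sharded
        # gradient tensor exists instead of raising AttributeError.
        if self._grad_sharded_tensor:
            return self._grad_sharded_tensor.payload
        else:
            return None

    @grad.setter
    def grad(self, t: torch.Tensor):
        self._grad_sharded_tensor = ShardedTensorStub(t)


p = ParamStub()
assert p.grad is None              # safe before any backward pass
p.grad = torch.zeros(2)
assert torch.equal(p.grad, torch.zeros(2))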