[Tensor] fix init context (#931)

* change torch.Parameter to ColoParameter

* fix post assignment for init context

* polish

* polish
This commit is contained in:
Ziyue Jiang
2022-05-11 15:48:12 +08:00
committed by GitHub
parent dfc88b85ea
commit d73c2b1d79
2 changed files with 50 additions and 15 deletions

View File

@@ -370,16 +370,22 @@ def _run_pretrain_load():
dict_pretrained = {}
dict_col = {}
c_ref = 0
for name, param in model_pretrained.named_parameters():
dict_pretrained[name] = param
c_ref += 1
c1 = 0
c2 = 0
for name, param in model.colo_named_parameters():
if isinstance(param, ColoParameter):
c1 = c1 + 1
c1 += 1
else:
c2 = c2 + 1
c2 +=1
dict_col[name] = param
assert c_ref == c1
assert c2 == 0
if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
assert model.cls.predictions.decoder.bias is model.cls.predictions.bias
for name, param in dict_pretrained.items():
check_equal(param, dict_col[name])
@@ -423,5 +429,4 @@ if __name__ == '__main__':
# test_model_parameters()
# test_colo_optimizer()
# test_model()
# _test_pretrain_load(4)
_run_pretrain_load()
_test_pretrain_load(4)