mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
[Tensor] fix init context (#931)
* change torch.Parameter to ColoParameter * fix post assignment for init context * polish * polish
This commit is contained in:
@@ -370,16 +370,22 @@ def _run_pretrain_load():
|
||||
|
||||
dict_pretrained = {}
|
||||
dict_col = {}
|
||||
c_ref = 0
|
||||
for name, param in model_pretrained.named_parameters():
|
||||
dict_pretrained[name] = param
|
||||
c_ref += 1
|
||||
c1 = 0
|
||||
c2 = 0
|
||||
for name, param in model.colo_named_parameters():
|
||||
if isinstance(param, ColoParameter):
|
||||
c1 = c1 + 1
|
||||
c1 += 1
|
||||
else:
|
||||
c2 = c2 + 1
|
||||
c2 +=1
|
||||
dict_col[name] = param
|
||||
assert c_ref == c1
|
||||
assert c2 == 0
|
||||
if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
|
||||
assert model.cls.predictions.decoder.bias is model.cls.predictions.bias
|
||||
|
||||
for name, param in dict_pretrained.items():
|
||||
check_equal(param, dict_col[name])
|
||||
@@ -423,5 +429,4 @@ if __name__ == '__main__':
|
||||
# test_model_parameters()
|
||||
# test_colo_optimizer()
|
||||
# test_model()
|
||||
# _test_pretrain_load(4)
|
||||
_run_pretrain_load()
|
||||
_test_pretrain_load(4)
|
||||
|
Reference in New Issue
Block a user