[Tensor] add from_pretrained support and bert pretrained test (#921)

* add from_pretrained support and test

* polish

* polish

* polish

* polish
Author: Ziyue Jiang
Date: 2022-05-09 16:11:47 +08:00
Committed by: GitHub
Parent: 1d625fcd36
Commit: c195d2814c

4 changed files with 158 additions and 20 deletions


@@ -23,7 +23,7 @@ from transformers.file_utils import ModelOutput
 from dataclasses import fields

-def _post_init_colo(self):
+def _post_init_colotensor(self):
     class_fields = fields(self)

     # Safety and consistency checks
     if len(class_fields) == 0:
@@ -72,7 +72,7 @@ def _post_init_colo(self):
         self[field.name] = v

-ModelOutput.__post_init__ = _post_init_colo
+ModelOutput.__post_init__ = _post_init_colotensor

 # complete the hack
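
For context on the hack being renamed here: transformers' stock ModelOutput.__post_init__ walks the dataclass fields and decides how to unpack them with isinstance(x, torch.Tensor) checks, which a ColoTensor fails, so Colossal-AI replaces the whole method at import time. A condensed sketch of that pattern (illustrative only; the widened type test and the elided body are stand-ins for the real patch):

    import torch
    from transformers.file_utils import ModelOutput

    def _is_tensor_like(x):
        # assumption for this sketch: ColoTensor is importable like this
        from colossalai.tensor import ColoTensor
        return isinstance(x, (torch.Tensor, ColoTensor))

    def _post_init_colotensor(self):
        # same field bookkeeping as the stock __post_init__, with every
        # isinstance(value, torch.Tensor) check widened to _is_tensor_like(value)
        ...

    # replace the method on the base class so every ModelOutput subclass
    # (including BertForMaskedLM's outputs) accepts ColoTensor fields
    ModelOutput.__post_init__ = _post_init_colotensor
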
@@ -278,6 +278,26 @@ def test_colo_optimizer():
         if i > 5:
             break

+def _test_pretrained():
+    from _utils import check_equal
+    from transformers import BertForMaskedLM
+
+    set_seed(1)
+    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
+    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+
+    model_pretrained = model_pretrained.cuda()
+    model = model.cuda()
+
+    dict_pretrained = {}
+    dict_col = {}
+    for name, param in model_pretrained.named_parameters():
+        dict_pretrained[name] = param
+    for name, param in model.named_parameters():
+        dict_col[name] = param
+
+    for name, param in dict_pretrained.items():
+        check_equal(param, dict_col[name])
+
 def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
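
check_equal is imported from the test-local _utils module and does not appear in this diff. A plausible minimal version, written here under the assumption that either argument may be a ColoTensor-style wrapper exposing torch_tensor():

    import torch

    def check_equal(a, b):
        # hypothetical reimplementation; the real helper lives in _utils.py
        # unwrap ColoTensor-style wrappers if present (assumed interface)
        a = a.torch_tensor() if hasattr(a, 'torch_tensor') else a
        b = b.torch_tensor() if hasattr(b, 'torch_tensor') else b
        assert torch.allclose(a, b), \
            f'max abs diff: {(a - b).abs().max().item()}'
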
@@ -377,4 +397,5 @@ def test_model(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    test_model()
+    # test_model()
+    _test_pretrained()
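
set_seed is defined elsewhere in this test file and not shown in the diff; a typical implementation seeds every RNG in play, so that any weights from_pretrained has to initialize randomly (i.e. ones absent from the checkpoint) come out identical across the two loads:

    import random
    import numpy as np
    import torch

    def set_seed(seed: int):
        # illustrative version; the test file defines its own
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
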


@@ -1,6 +1,6 @@
 from numpy import allclose
 import torch
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoTensor, ColoParameter
 from copy import deepcopy
 from colossalai.utils import get_current_device
@@ -16,7 +16,7 @@ def test_layernorm():
     delattr(ln_op_colo, 'weight')
     weight_clone = ln_op.weight.clone().detach()
     weight_clone.requires_grad = True
-    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+    setattr(ln_op_colo, 'weight', ColoParameter.init_from_torch_tensor(tensor=weight_clone))

     output = ln_op(input_t)
     output_colo = ln_op_colo(input_t_colo)
@@ -39,8 +39,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
-    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
+    sharded_weight = ColoParameter.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoParameter.init_from_torch_tensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
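
Both tests use the same delattr/setattr dance to swap a module's registered nn.Parameter for a Colossal-AI wrapper without touching the module's code, and the switch from ColoTensor to ColoParameter presumably matters for the new pretrained test because only parameter-typed wrappers surface through named_parameters(). A standalone sketch of the pattern, reusing the init_from_torch_tensor(tensor=...) constructor seen in this diff (the final forward call assumes Colossal-AI's op patches are active so the wrapped weight dispatches correctly):

    import torch
    import torch.nn as nn
    from colossalai.tensor import ColoParameter

    fc = nn.Linear(4, 8)

    # keep a trainable clone of the original weight
    weight_clone = fc.weight.clone().detach()
    weight_clone.requires_grad = True

    # delattr removes the entry from fc._parameters; setattr then binds the
    # wrapper under the same name, so fc.forward() resolves it transparently
    delattr(fc, 'weight')
    setattr(fc, 'weight', ColoParameter.init_from_torch_tensor(tensor=weight_clone))

    out = fc(torch.randn(2, 4))    # dispatches through the patched linear op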