[pipeline] rewrite bert tests and fix some bugs (#4409)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name conflict

* add bloom model and policy, revise the base class of policy

* revise

* revision

* add bert_for_pretraining

* add bert_for_pretraining forward and policy

* fix typos

* cancel warning

* change the intermediate output to default dict

* change the default output of get_shared_params

* rewrite bert test

* rewrite bert test

* fix some bugs

* del pipeline tests

* del pipeline tests

* del useless print

* del useless print

* rewrite data repeats
Jianghai
2023-08-11 10:32:53 +08:00
committed by Hongxin Liu
parent d2cd48e0be
commit 7596e9ae08
4 changed files with 83 additions and 154 deletions


@@ -131,6 +131,8 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
 def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
                                              data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
                                              booster: Booster):
+    org_model.cuda()
+    sharded_model.cuda()
 
     def _criterion(outputs, inputs):
         outputs = output_transform_fn(outputs)
@@ -141,7 +143,8 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
         data = {
-            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
+            k: v.to('cuda').repeat(*([4] + [1] *
+                                     (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
             for k, v in data.items()
         }
         data_iter = iter([data])
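
The rewritten repeat is needed because repeat(4, 1) assumes every batched tensor is 2-D, while the test data can also contain 1-D label tensors; building the repeat pattern from v.dim() duplicates a tensor of any rank along its batch dimension only. A minimal sketch of the same idea, with an illustrative helper name (repeat_batch is not part of the patch):

    import torch

    def repeat_batch(v, times: int = 4):
        # Replicate a tensor along its first (batch) dimension only, whatever its rank.
        # repeat(4, 1) would turn a 1-D tensor into a 2-D one and error on 3-D tensors;
        # the pattern [times, 1, 1, ...] keeps the shape intact apart from the batch dim.
        if torch.is_tensor(v):
            return v.repeat(*([times] + [1] * (v.dim() - 1)))
        return v

    ids = torch.zeros(2, 16, dtype=torch.long)     # 2-D, e.g. input_ids
    labels = torch.zeros(2, dtype=torch.long)      # 1-D, e.g. next_sentence_label
    print(repeat_batch(ids).shape)                 # torch.Size([8, 16])
    print(repeat_batch(labels).shape)              # torch.Size([8])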
@@ -162,6 +165,7 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     org_model.train()
+    data = {k: v.cuda() for k, v in data.items()}
     org_output = org_model(**data)
 
     org_loss = criterion(org_output)
     org_loss.backward()
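
Together, the added .cuda() moves and the explicit data transfer keep the reference model's forward/backward directly comparable to the sharded run. A standalone sketch of that comparison pattern, simplified and not the repository's exact API (two identical Linear layers stand in for the original and sharded models):

    import torch
    import torch.nn as nn

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Two copies with identical weights stand in for org_model / sharded_model.
    org_model = nn.Linear(16, 4).to(device)
    ref_model = nn.Linear(16, 4).to(device)
    ref_model.load_state_dict(org_model.state_dict())

    data = {'input': torch.randn(2, 16)}
    data = {k: v.to(device) for k, v in data.items()}   # mirrors the added .cuda() step

    def criterion(out):
        return out.mean()

    org_loss = criterion(org_model(**data))
    org_loss.backward()

    ref_loss = criterion(ref_model(**data))
    ref_loss.backward()

    # Losses (and gradients) should agree, which is what the test utilities assert.
    assert torch.allclose(org_loss, ref_loss, atol=1e-5, rtol=1e-3)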
@@ -226,7 +230,6 @@ def check_grad(org_model: Module,
                atol: float = 1e-5,
                rtol: float = 1e-3,
                verbose: bool = False):
     for suffix in layer_suffix:
         org_grad = getattr_(org_model, suffix).weight.grad
         shard_grad = getattr_(sharded_model, suffix).weight.grad
@@ -242,7 +245,6 @@ def check_grad(org_model: Module,
         # embedding may be resized when using tensor parallel
         if shard_grad.shape[0] > org_grad.shape[0]:
             shard_grad = shard_grad[:org_grad.shape[0], :]
         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' grad: {org_grad}, {shard_grad}")