Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[pipeline] rewrite bert tests and fix some bugs (#4409)
* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update formatting
* update the code
* fix bugs
* fix name conflict
* add bloom model and policy, revise the base class of policy
* revise
* revision
* add bert_for_pretraining
* add bert_for_pretraining forward and policy
* fix typos
* cancel warning
* change the intermediate output to default dict
* change the default output of get_shared_params
* rewrite bert test
* rewrite bert test
* fix some bugs
* del pipeline tests
* del pipeline tests
* del useless print
* del useless print
* rewrite data repeats
@@ -131,6 +131,8 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c

 def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
                                             data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
                                             booster: Booster):
     org_model.cuda()
     sharded_model.cuda()

     def _criterion(outputs, inputs):
         outputs = output_transform_fn(outputs)
@@ -141,7 +143,8 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
         data = {
-            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
+            k: v.to('cuda').repeat(*([4] + [1] *
+                                     (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
             for k, v in data.items()
         }
         data_iter = iter([data])
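This is the "rewrite data repeats" change from the commit message: repeat(4, 1) only replicates a 2-D tensor correctly, silently reshaping 1-D tensors and raising on 3-D ones, while the new call derives the repeat counts from the tensor's own rank so that only the batch dimension is duplicated. A minimal sketch of the difference (the tensor names and shapes are illustrative, and the factor 4 is taken from the diff, presumably the number of pipeline micro-batches):

import torch

v1 = torch.randint(0, 2, (2,))        # 1-D, e.g. a per-sample label
v2 = torch.randint(0, 100, (2, 16))   # 2-D, e.g. input_ids
v3 = torch.rand(2, 16, 8)             # 3-D, e.g. per-token features

# old behaviour: repeat(4, 1) is only right for 2-D tensors
print(v1.repeat(4, 1).shape)   # torch.Size([4, 2]) -- silently promoted to 2-D
print(v2.repeat(4, 1).shape)   # torch.Size([8, 16]) -- correct
# v3.repeat(4, 1) raises: repeat dims cannot be fewer than tensor dims

# new behaviour: replicate dim 0 only, keep every other dim unchanged
for v in (v1, v2, v3):
    print(v.repeat(*([4] + [1] * (v.dim() - 1))).shape)
# torch.Size([8]), torch.Size([8, 16]), torch.Size([8, 16, 8])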
@@ -162,6 +165,7 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     org_model.train()
+    data = {k: v.cuda() for k, v in data.items()}
     org_output = org_model(**data)

     org_loss = criterion(org_output)
     org_loss.backward()
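The one added line here moves the generated batch onto the GPU before the unsharded reference model runs; the tensors from data_gen_fn start on the CPU and were previously only moved inside the pipeline branch, so org_model(**data) could hit a device mismatch. A toy sketch of that reference path, where ToyModel and the mean-loss criterion are hypothetical stand-ins for the test's real model and criterion:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # hypothetical stand-in for the real model under test
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 4)

    def forward(self, input_ids):
        return self.proj(input_ids)

org_model = ToyModel().cuda()
org_model.train()

data = {'input_ids': torch.rand(2, 16)}        # batch generated on CPU
data = {k: v.cuda() for k, v in data.items()}  # the added line: move to GPU
org_output = org_model(**data)                 # device mismatch without it
org_loss = org_output.mean()                   # stand-in criterion
org_loss.backward()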
@@ -226,7 +230,6 @@ def check_grad(org_model: Module,
                atol: float = 1e-5,
                rtol: float = 1e-3,
                verbose: bool = False):

     for suffix in layer_suffix:
         org_grad = getattr_(org_model, suffix).weight.grad
         shard_grad = getattr_(sharded_model, suffix).weight.grad
@@ -242,7 +245,6 @@ def check_grad(org_model: Module,
         # embedding may be resized when using tensor parallel
         if shard_grad.shape[0] > org_grad.shape[0]:
             shard_grad = shard_grad[:org_grad.shape[0], :]

         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
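The slicing kept in this last hunk compensates for tensor parallelism resizing embeddings: the vocabulary dimension may be padded (for example, to a multiple of the tensor-parallel world size), so the gathered shard gradient can carry extra rows that must be trimmed before comparing against the original. A self-contained sketch with assumed shapes (30522 is BERT's vocabulary size; the padding to 30528 rows is invented for illustration):

import torch

org_grad = torch.rand(30522, 128)               # vocab x hidden gradient
pad = torch.zeros(6, 128)                       # rows added by TP resizing
shard_grad = torch.cat([org_grad, pad], dim=0)  # padded to 30528 rows

# embedding may be resized when using tensor parallel:
# trim the padded rows before comparing
if shard_grad.shape[0] > org_grad.shape[0]:
    shard_grad = shard_grad[:org_grad.shape[0], :]

assert torch.allclose(org_grad, shard_grad, atol=1e-5, rtol=1e-3)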