[shardformer] shardformer support opt models (#4091)

* [shardformer] shardformer support opt models

* [shardformer] shardformer support opt models, fix

* [shardformer] shardformer support opt models, fix

* [shardformer] shardformer support opt models, fix
jiangmingyan
2023-06-27 17:39:29 +08:00
committed by Frank Lee
parent d33a44e8c3
commit ac80937138
6 changed files with 264 additions and 10 deletions

@@ -25,7 +25,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
    # switch to train mode
    original_model.train()
    sharded_model.train()
    # run forward
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
@@ -34,5 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss