mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[shardformer] shardformer support opt models (#4091)
* [shardformer] shardformer support opt models * [shardformer] shardformer support opt models, fix * [shardformer] shardformer support opt models, fix * [shardformer] shardformer support opt models, fix
This commit is contained in:
@@ -25,7 +25,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
|
||||
# switch to train mode
|
||||
original_model.train()
|
||||
sharded_model.train()
|
||||
|
||||
# run forward
|
||||
org_output = original_model(**data)
|
||||
org_output = output_transform_fn(org_output)
|
||||
@@ -34,5 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
|
||||
shard_output = sharded_model(**data)
|
||||
shard_output = output_transform_fn(shard_output)
|
||||
shard_loss = loss_fn(shard_output)
|
||||
|
||||
return org_output, org_loss, shard_output, shard_loss
|
||||
return org_output, org_loss, shard_output, shard_loss
|
Reference in New Issue
Block a user