[shardformer] polish chatglm code

This commit is contained in:
klhhhhh
2023-07-12 15:25:07 +08:00
committed by Hongxin Liu
parent 8620009dd7
commit 1a29e8fc29
3 changed files with 4 additions and 44 deletions

View File

@@ -19,6 +19,7 @@ from colossalai.testing import (
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,