[shardformer] update shardformer readme (#4689)

* [shardformer] update shardformer readme
Author:    flybird11111
Date:      2023-09-12 15:14:24 +08:00
Committer: GitHub
Parent:    1d454733c4
Commit:    8844691f4b

4 changed files with 90 additions and 72 deletions


@@ -29,7 +29,8 @@ MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
                                          intermediate_size=256,
                                          num_attention_heads=4,
                                          max_position_embeddings=128,
-                                         num_labels=16)
+                                         num_labels=16,
+                                         pad_token_id=2)
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
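
For context (not part of the diff): LlamaForSequenceClassification pools the logit of each sequence's last non-padding token, and Hugging Face raises "Cannot handle batch sizes > 1 if no padding token is defined" unless config.pad_token_id is set. Llama defines no pad token by default, so the benchmark pins it to 2 (Llama's eos id). A minimal sketch of the fixed config in isolation; the hidden_size value below is illustrative, not taken from the diff:

    import torch
    import transformers

    config = transformers.LlamaConfig(num_hidden_layers=4,
                                      hidden_size=128,           # illustrative size
                                      intermediate_size=256,
                                      num_attention_heads=4,
                                      max_position_embeddings=128,
                                      num_labels=16,
                                      pad_token_id=2)            # without this, batched forward errors out
    model = transformers.LlamaForSequenceClassification(config)

    # Batched forward now works because the model can locate each
    # sequence's last non-padding token to pool classification logits.
    input_ids = torch.randint(3, config.vocab_size, (4, 16))      # batch of 4, seq len 16
    logits = model(input_ids=input_ids).logits
    print(logits.shape)                                           # torch.Size([4, 16]): one score per label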
@@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
     if provider == "shard_model":
         shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
         shard_former = ShardFormer(shard_config=shard_config)
-        sharded_model = shard_former.optimize(model).cuda()
+        sharded_model, _ = shard_former.optimize(model)
+        sharded_model = sharded_model.cuda()
         fn = lambda: train(sharded_model, data)
         ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
         return ms
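
The second hunk tracks an API change: ShardFormer.optimize() now returns a (model, shared_params) tuple rather than the model alone, so the result must be unpacked before moving the model to the GPU. A minimal sketch of the updated call pattern, assuming a tensor-parallel-only run launched via torchrun; the model config and launch arguments are illustrative:

    import transformers
    import colossalai
    from colossalai.shardformer import ShardConfig, ShardFormer

    # Initialize the distributed runtime first (run this script with torchrun);
    # ShardFormer shards parameters across the process group set up here.
    colossalai.launch_from_torch(config={})

    model = transformers.LlamaForSequenceClassification(
        transformers.LlamaConfig(num_hidden_layers=4, pad_token_id=2))

    shard_config = ShardConfig(enable_fused_normalization=True,
                               enable_tensor_parallelism=True)
    shard_former = ShardFormer(shard_config=shard_config)

    # optimize() now returns a tuple; the second element holds parameters
    # shared across stages and can be ignored in this tensor-parallel-only setup.
    sharded_model, _ = shard_former.optimize(model)
    sharded_model = sharded_model.cuda()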