Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[shardformer] update shardformer readme (#4689)
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
* [shardformer] update shardformer readme
@@ -29,7 +29,8 @@ MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
                                         intermediate_size=256,
                                         num_attention_heads=4,
                                         max_position_embeddings=128,
-                                        num_labels=16)
+                                        num_labels=16,
+                                        pad_token_id=2)
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
 
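The first hunk adds `pad_token_id` to the benchmark's `LlamaConfig`. `LlamaForSequenceClassification` pools the logits of the last non-padding token in each sequence, so it refuses batch sizes greater than 1 unless the config defines a padding token. A minimal sketch of the effect; `hidden_size=128` is an assumption for a small model, the other values come from the diff context:

```python
import torch
import transformers

# Small, fast-to-build config mirroring the benchmark; hidden_size is assumed.
MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
                                        hidden_size=128,
                                        intermediate_size=256,
                                        num_attention_heads=4,
                                        max_position_embeddings=128,
                                        num_labels=16,
                                        pad_token_id=2)   # required for batched classification

model = transformers.LlamaForSequenceClassification(MODEL_CONFIG)
input_ids = torch.randint(0, MODEL_CONFIG.vocab_size, (4, 16))  # batch of 4, seq len 16
logits = model(input_ids).logits                                # shape: (4, num_labels) == (4, 16)
```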
@@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
     if provider == "shard_model":
         shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
         shard_former = ShardFormer(shard_config=shard_config)
-        sharded_model = shard_former.optimize(model).cuda()
+        sharded_model, _ = shard_former.optimize(model)
+        sharded_model = sharded_model.cuda()
         fn = lambda: train(sharded_model, data)
         ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
         return ms