Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[shardformer]fix gpt2 double head (#4663)
* [shardformer] fix gpt2 test
* fix
* [shardformer] add todo
@@ -141,13 +141,13 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     data = data_gen_fn()
 
     if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
-        seq_len = data['input_ids'].shape[1]
+        seq_len = data['input_ids'].shape[-1]
         lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
         times = lcm // seq_len
         input_shape = data['input_ids'].shape
         for k, v in data.items():
             if v.shape == input_shape:
-                data[k] = v.repeat(1, times)
+                data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,))
 
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
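For context on the fix: torch.Tensor.repeat takes one repeat factor per tensor dimension, so the old 2-D-only call v.repeat(1, times) fails once the GPT-2 double-head test feeds 3-D inputs of shape (batch, num_choices, seq_len). The sketch below only illustrates the intent of this padding step, tiling the sequence dimension until its length is divisible by the tensor-parallel size. The helper name tile_sequence_dim and the sample batch are hypothetical, and the rank-agnostic repeat spec (1,) * (v.dim() - 1) + (times,) is one possible way to express the tiling, not the exact expression used in the commit above.

import math

import torch


def tile_sequence_dim(data: dict, tp_size: int) -> dict:
    # Hypothetical helper: tile every tensor that shares input_ids' shape along
    # its last (sequence) dimension until seq_len is divisible by tp_size.
    seq_len = data['input_ids'].shape[-1]
    lcm = tp_size * seq_len // math.gcd(tp_size, seq_len)
    times = lcm // seq_len
    input_shape = data['input_ids'].shape
    for k, v in data.items():
        if v.shape == input_shape:
            # One repeat factor per dimension: keep leading dims, tile the last.
            data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))
    return data


# Double-head style 3-D inputs: (batch, num_choices, seq_len) = (2, 2, 6).
batch = {
    'input_ids': torch.randint(0, 100, (2, 2, 6)),
    'attention_mask': torch.ones(2, 2, 6, dtype=torch.long),
}
padded = tile_sequence_dim(batch, tp_size=4)  # lcm(4, 6) = 12, so times = 2
assert padded['input_ids'].shape == (2, 2, 12)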