[shardformer]fix gpt2 double head (#4663)

* [shardformer]fix gpt2 test

* fix

* [shardformer] add todo
Author: flybird11111 (committed by GitHub)
Date: 2023-09-11 18:35:03 +08:00
Parent: 554aa9592e
Commit: eedaa3e1ef

5 changed files with 38 additions and 29 deletions

```diff
@@ -141,13 +141,13 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     data = data_gen_fn()
 
     if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
-        seq_len = data['input_ids'].shape[1]
+        seq_len = data['input_ids'].shape[-1]
         lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
         times = lcm // seq_len
         input_shape = data['input_ids'].shape
         for k, v in data.items():
             if v.shape == input_shape:
-                data[k] = v.repeat(1, times)
+                data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,))
 
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
```
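
For context, this hunk addresses the sequence-parallel test path: `seq_len` is now read from the last dimension of `input_ids` (robust to inputs with more than two dimensions), and the inputs are tiled along that same dimension until the sequence length is a multiple of `tp_size`, since sequence parallelism shards the sequence dimension across the tensor-parallel ranks. Below is a minimal standalone sketch of that tiling logic, not code from the commit; `tp_size` and the sample `data` dict are made up for illustration, and the repeat pattern `(1,) * (v.dim() - 1) + (times,)` is one illustrative way to tile only the last dimension.

```python
# Sketch (assumed names, not the commit's code): stretch every input tensor
# along the sequence dimension until seq_len is divisible by tp_size.
import math

import torch

tp_size = 4  # hypothetical tensor-parallel world size
data = {
    'input_ids': torch.randint(0, 100, (2, 6)),
    'attention_mask': torch.ones(2, 6, dtype=torch.long),
}

seq_len = data['input_ids'].shape[-1]                    # 6
lcm = tp_size * seq_len // math.gcd(tp_size, seq_len)    # lcm(4, 6) = 12
times = lcm // seq_len                                   # tile the sequence twice

input_shape = data['input_ids'].shape
for k, v in data.items():
    if v.shape == input_shape:
        # Repeat only the last (sequence) dimension `times` times,
        # leaving batch and any leading dimensions untouched.
        data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))

assert data['input_ids'].shape[-1] % tp_size == 0        # 12 splits evenly across 4 ranks
```

Tiling up to the least common multiple keeps the repeated inputs valid for any `tp_size`/`seq_len` pair while adding the fewest extra tokens needed for an even split.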