shardformer fp8

This commit is contained in:
GuangyaoZhang
2024-07-08 07:04:48 +00:00
parent 51f916b11d
commit 457a0de79f
16 changed files with 520 additions and 234 deletions

View File

@@ -1137,6 +1137,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -1204,6 +1205,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states,
dim=1,
process_group=shard_config.sequence_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
hidden_states = self.ln_f(hidden_states)