mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
shardformer fp8
This commit is contained in:
@@ -1137,6 +1137,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
@@ -1204,6 +1205,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
Reference in New Issue
Block a user