mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[shardformer, pipeline] add gradient_checkpointing_ratio
and heterogenous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests
This commit is contained in:
@@ -138,13 +138,25 @@ class LlamaPipelineForwards:
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
num_ckpt_layers = 0
|
||||
if self.gradient_checkpointing and self.training:
|
||||
num_ckpt_layers = end_idx - start_idx
|
||||
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
|
||||
if shard_config.gradient_checkpoint_config is not None:
|
||||
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
|
||||
stage=stage_manager.stage,
|
||||
num_layers=end_idx - start_idx,
|
||||
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
|
||||
)
|
||||
assert num_ckpt_layers <= end_idx - start_idx
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if idx - start_idx < num_ckpt_layers:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
Reference in New Issue
Block a user