Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-01-29 21:49:54 +00:00)
[shardformer, pipeline] add gradient_checkpointing_ratio and heterogenous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`
* feat: apply `GradientCheckpointConfig` to policy and llama_forward
* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager
* fix: add optional args for `distribute_layer` and `get_stage_index`
* fix: fix changed API calls
* test: update llama tests
* style: polish `GradientCheckpointConfig`
* fix: fix pipeline utils tests
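For orientation, a minimal usage sketch of the new config (not part of this commit): the parameter names come from the test changes below, while the assumption that `HybridParallelPlugin` accepts the config through a `gradient_checkpoint_config` keyword, the same way the test configs pass it, is mine.

from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

# Checkpoint roughly half of the model layers on every pipeline stage.
ckpt_config = PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5)

# Or describe a heterogeneous split explicitly: 8 model layers over 2 stages
# (1 model chunk each), recomputing 4 layers on stage 0 and none on stage 1.
ckpt_config = PipelineGradientCheckpointConfig(
    num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
)

# Assumed wiring: hand the config to the hybrid-parallel plugin, as the llama tests
# below do through their `test_config` dictionaries.
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=4, gradient_checkpoint_config=ckpt_config)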
@@ -49,9 +49,9 @@ if HAS_LLAMA:
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    config = LlamaConfig(
-       num_hidden_layers=4,
-       hidden_size=128,
-       intermediate_size=256,
+       num_hidden_layers=8,
+       hidden_size=32,
+       intermediate_size=64,
        num_attention_heads=4,
        max_position_embeddings=128,
        num_labels=16,
@@ -1,4 +1,23 @@
import random

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.t5 import T5BasePolicy
from colossalai.shardformer.shard.shard_config import ShardConfig


class _ShardConfig(ShardConfig):
    def __post_init__(self):
        pass


class _PipelineStageManager(PipelineStageManager):
    def __init__(self):
        self.is_interleave = False
        self.num_layers_per_stage = None

    @property
    def num_stages(self):
        return random.randint(5, 10)


def test_t5_pipeline_distribution():
@@ -10,7 +29,10 @@ def test_t5_pipeline_distribution():
        "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
    }

    stage_manager = _PipelineStageManager()
    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = T5BasePolicy()
    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        _, decoder_starting_stage = policy.distribute_t5_layers(
            test_dict["num_encoder_layers"][i],
@@ -35,7 +57,10 @@ def test_t5_pipeline_layers():
    }

    for i in range(num_test_cases):
        stage_manager = _PipelineStageManager()
        shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
        policy = T5BasePolicy()
        policy.set_shard_config(shard_config)
        layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
            test_dict["num_encoder_layers"][i],
            test_dict["num_decoder_layers"][i],
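Aside: the `_ShardConfig` / `_PipelineStageManager` stubs above (and their twins in the Whisper test below) exist because layer distribution now goes through the shard config's stage manager; `__post_init__` is blanked out and the parent constructor is skipped so no process groups are needed, while `num_stages` returns a random value to vary the distribution. A hedged sketch of the relocated helpers named in the commit message follows; the exact method names and signatures on `PipelineStageManager` are assumptions, not verified here.

# Hedged sketch only: method names follow the commit message ("distribute_layer",
# "get_stage_index" moved to PipelineStageManager); signatures are assumed.
stage_manager = _PipelineStageManager()                   # stub: random num_stages, no process groups
layers_per_stage = stage_manager.distribute_layers(16)    # split 16 layers across the stages
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage, stage=0)  # layer range owned by stage 0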
@@ -1,4 +1,23 @@
import random

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.whisper import WhisperPolicy
from colossalai.shardformer.shard.shard_config import ShardConfig


class _ShardConfig(ShardConfig):
    def __post_init__(self):
        pass


class _PipelineStageManager(PipelineStageManager):
    def __init__(self):
        self.is_interleave = False
        self.num_layers_per_stage = None

    @property
    def num_stages(self):
        return random.randint(5, 10)


def test_whisper_pipeline_distribution():
@@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution():
        "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
    }

    stage_manager = _PipelineStageManager()
    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = WhisperPolicy()
    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        _, decoder_starting_stage = policy.distribute_whisper_layers(
            test_dict["num_encoder_layers"][i],
@@ -34,7 +56,10 @@ def test_whisper_pipeline_layers():
        ],
    }

    stage_manager = _PipelineStageManager()
    shard_config = _ShardConfig(pipeline_stage_manager=stage_manager)
    policy = WhisperPolicy()
    policy.set_shard_config(shard_config)
    for i in range(num_test_cases):
        layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
            test_dict["num_encoder_layers"][i],
@@ -5,6 +5,7 @@ import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import PipelineGradientCheckpointConfig
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -24,9 +25,13 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )
    if enable_gradient_checkpointing:
        org_model.gradient_checkpointing_enable()
        sharded_model.unwrap().gradient_checkpointing_enable()

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
@@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
        },
        {
            "tp_size": 1,
@@ -108,6 +115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "num_microbatches": 4,
            "use_lazy_init": False,
            "precision": "fp32",
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
            ),
        },
        {
            "tp_size": 4,
@@ -189,6 +200,13 @@ def run_llama_test(test_config):
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
            "enable_gradient_checkpointing": True,
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                num_stages=2,
                num_model_chunks=2,
                num_model_layers=8,
                num_ckpt_layers_per_stage=[0, 1, 2, 2],
            ),
        },
    ],
)
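A quick arithmetic note on the checkpointing configurations above (reading `num_ckpt_layers_per_stage` as one entry per stage per model chunk is my inference from the list lengths, not from documentation):

# Hedged sanity check of the test configurations above, not library code.
num_model_layers = 8

# Config with num_stages=2, num_model_chunks=1: one entry per stage.
assert len([4, 0]) == 2 * 1
print(sum([4, 0]) / num_model_layers)        # 0.5 -> same budget as gradient_checkpointing_ratio=0.5

# Config with num_stages=2, num_model_chunks=2 (interleaved): one entry per stage-chunk.
assert len([0, 1, 2, 2]) == 2 * 2
print(sum([0, 1, 2, 2]) / num_model_layers)  # 0.625 -> 5 of the 8 layers are recomputed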