Mirror of https://github.com/hpcaitech/ColossalAI.git
[lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
@@ -9,7 +9,8 @@ from torch.optim import AdamW
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 from tests.test_checkpoint_io.utils import shared_tempdir
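The two added imports set up the rest of the change: HybridParallelPlugin joins the plugins under test, and HybridParallelModule lets the test recognize the extra wrapper that plugin puts around the model. For orientation, a minimal sketch of how such a plugin plugs into the Booster API; the enable_lora call and its lora_config= keyword are assumed from the surrounding test file, and the snippet presumes a torchrun-style launch:

```python
import torch
from peft import LoraConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes a distributed environment set up by torchrun (or colossalai run);
# HybridParallelPlugin needs an initialized process group. The exact
# launch_from_torch signature varies across ColossalAI versions.
colossalai.launch_from_torch()

model = torch.nn.Linear(8, 8)
plugin = HybridParallelPlugin(tp_size=1, pp_size=1)  # degrees as in the diff
booster = Booster(plugin=plugin)

# enable_lora wraps the model with PEFT adapters before boosting; the exact
# keyword (lora_config=) is an assumption, not confirmed by this diff.
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)
```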
@@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
     model = model_fn()
     lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
 
-    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
+    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)]
     test_configs = [
         {
             "lora_config": lora_config,
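With HybridParallelPlugin(tp_size=1, pp_size=1) appended, the same LoRA test matrix now also exercises the hybrid-parallel code path. A hypothetical sketch of the loop the test presumably runs for each plugin/config pair; check_fwd_bwd's body is not shown in this diff, so the structure below is illustrative:

```python
# Illustrative only: the real loop body lives in check_fwd_bwd and is not
# part of this diff.
for plugin in test_plugins:
    for test_config in test_configs:
        model = model_fn()
        booster = Booster(plugin=plugin)
        model = booster.enable_lora(model, lora_config=test_config["lora_config"])
        optimizer = AdamW(model.parameters(), lr=1e-3)
        # boost() returns (model, optimizer, criterion, dataloader, lr_scheduler)
        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=loss_fn)
```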
@@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
 
     # test fwd bwd correctness
     test_model = model_load
+    if isinstance(model_load, HybridParallelModule):
+        model_load = model_load.module.module
     model_copy = copy.deepcopy(model_load)
 
     data = data_gen_fn()
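The third hunk is the substantive one: after boosting with HybridParallelPlugin, the loaded model is a HybridParallelModule, and two .module hops are needed to reach the underlying model, since the plugin wrapper holds the PEFT-wrapped model, which in turn holds the base model. Unwrapping before copy.deepcopy presumably keeps distributed state out of the copy. The same logic, factored into a helper for clarity; the wrapper layout here is inferred from the diff, not from the plugin's source:

```python
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule

def unwrap_for_comparison(model):
    # HybridParallelModule -> PEFT-wrapped model -> base model, hence the
    # double .module hop mirrored from the diff above.
    if isinstance(model, HybridParallelModule):
        return model.module.module
    return model
```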