[lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add liscense and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded codes

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements
This commit is contained in:
Baizhou Zhang
2023-10-31 15:19:37 +08:00
committed by Hongxin Liu
parent c1594e4bad
commit 14b0d4c7e5
11 changed files with 265 additions and 7 deletions

View File

@@ -0,0 +1,108 @@
import copy
import os
import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import (
assert_equal,
assert_not_equal,
check_state_dict_equal,
clear_cache_before_run,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = booster.enable_lora(model, lora_config=lora_config)
model_copy = copy.deepcopy(model)
optimizer = AdamW(model.parameters(), lr=0.001)
criterion = loss_fn
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
output = output_transform_fn(output)
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
if "lora_" in n1:
# lora modules require gradients, thus updated
assert p1.requires_grad
assert_not_equal(p1.to(p2.device), p2)
else:
if not p1.requires_grad:
assert_equal(p1.to(p2.device), p2)
@clear_cache_before_run()
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
plugin = TorchDDPPlugin()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
model_save = model_fn()
model_load = copy.deepcopy(model_save)
booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, lora_config=lora_config)
model_save, _, _, _, _ = booster.boost(model_save)
with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()
# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
model_load, _, _, _, _ = booster.boost(model_load)
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_lora_test()
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
spawn(run_dist, 2)