[LowLevelZero] low level zero support lora (#5153)
* Add LoRA support to the LowLevelZero plugin
* Add LoRA checkpoint I/O tests
* Fix naming
Committed by: Hongxin Liu
Parent: 14b0d4c7e5
Commit: 8954a0c2e2
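The new tests drive LoRA through the Booster API. Below is a minimal sketch of that flow, assuming the distributed environment has already been launched (as run_dist does in the tests) and using model_fn / data_gen_fn / output_transform_fn as stand-ins for the model-zoo fixtures; it is not the tests themselves, just the pattern they exercise.

import torch
from peft import LoraConfig
from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)

model = model_fn()  # model-zoo factory (stand-in)
optimizer = Adam(model.parameters(), lr=1e-3)
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

# Inject the peft LoRA adapters before boosting, then train as usual.
model = booster.enable_lora(model, lora_config=lora_config)
criterion = lambda x: x.mean()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

data = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data_gen_fn().items()}
output = output_transform_fn(model(**data))
loss = criterion(output[list(output.keys())[0]])
booster.backward(loss, optimizer)
optimizer.step()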
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Dict, Iterator, List, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -51,6 +51,12 @@ class DPPluginWrapper(DPPluginBase):
     def no_sync(self, model: nn.Module) -> Iterator[None]:
         pass

+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        pass
+
+    def support_lora(self) -> bool:
+        pass
+

 def check_dataloader_sharding():
     plugin = DPPluginWrapper()
@@ -2,6 +2,7 @@ from typing import Optional

 import torch
 import torch.distributed as dist
+from peft import LoraConfig
 from torch.optim import Adam

 import colossalai
@@ -22,13 +23,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]


 @clear_cache_before_run()
-def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
     device = get_accelerator().get_current_device()
     try:
         plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
         booster = Booster(plugin=plugin)
         model = model_fn()
         optimizer = Adam(model.parameters(), lr=1e-3)
+
+        if lora_config is not None:
+            model = booster.enable_lora(model, lora_config=lora_config)
+
         criterion = lambda x: x.mean()
         data = data_gen_fn()

@@ -48,6 +53,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:

     except Exception as e:
         return repr(e)
         # raise e


 @parameterize("stage", [2])
@@ -91,10 +97,42 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


+@parameterize("stage", [2])
+@parameterize("model_name", ["transformers_llama"])
+def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
+    passed_models = []
+    failed_info = {}  # (model_name, error) pair
+
+    sub_model_zoo = model_zoo.get_sub_registry(model_name)
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        task_type = None
+        if name == "transformers_llama_for_casual_lm":
+            task_type = "CAUSAL_LM"
+        if name == "transformers_llama_for_sequence_classification":
+            task_type = "SEQ_CLS"
+        lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+        err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config)
+
+        torch.cuda.empty_cache()
+
+        if err is None:
+            passed_models.append(name)
+        else:
+            failed_info[name] = err
+            if early_stop:
+                break
+
+    if dist.get_rank() == 0:
+        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
+        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
+    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
+
+
 def run_dist(rank, world_size, port, early_stop: bool = True):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
     check_low_level_zero_plugin(early_stop=early_stop)
+    check_low_level_zero_lora(early_stop=early_stop)


 @rerun_if_address_is_in_use()
@@ -1,5 +1,9 @@
+from copy import deepcopy
+from typing import Optional
+
 import torch
 import torch.distributed as dist
+from peft import LoraConfig
 from torchvision.models import resnet18
 from utils import shared_tempdir

@@ -15,6 +19,7 @@ from colossalai.testing import (
     spawn,
 )
 from colossalai.zero import LowLevelZeroOptimizer
+from tests.kit.model_zoo import model_zoo


 # stage 1 and 2 process the optimizer/mode the same way
@@ -69,9 +74,107 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
     torch.cuda.empty_cache()


+def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
+    try:
+        plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload)
+        new_plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload)
+        booster = Booster(plugin=plugin)
+        new_booster = Booster(plugin=new_plugin)
+        model = model_fn()
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        new_model = deepcopy(model)
+        new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3)
+        model = booster.enable_lora(model, lora_config=lora_config)
+        criterion = lambda x: x.mean()
+        data = data_gen_fn()
+
+        data = {
+            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+        }
+
+        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+        output = model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+
+        booster.backward(loss, optimizer)
+        optimizer.step()
+
+        with shared_tempdir() as tempdir:
+            model_ckpt_path = f"{tempdir}/model"
+            optimizer_ckpt_path = f"{tempdir}/optimizer"
+
+            booster.save_lora_as_pretrained(model, model_ckpt_path)
+            booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
+            new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
+            new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+            check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+
+            # check master weight
+            assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+            working_param_id_set = set(id(p) for p in new_model.parameters())
+            for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+                assert p_id in working_param_id_set
+                working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+                padding = new_optimizer._param_store.get_param_padding_size(working_param)
+                padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+                working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+                assert torch.equal(
+                    working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
+                )
+
+            new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+            check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+
+    except Exception as e:
+        # return repr(e)
+        raise e
+
+
+@clear_cache_before_run()
+@parameterize("stage", [2])
+@parameterize("shard", [True, False])
+@parameterize("offload", [False, True])
+@parameterize("model_name", ["transformers_llama"])
+def check_low_level_zero_lora_checkpointIO(
+    stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True
+):
+    passed_models = []
+    failed_info = {}  # (model_name, error) pair
+
+    sub_model_zoo = model_zoo.get_sub_registry(model_name)
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name != "transformers_llama":
+            continue
+        task_type = None
+        if name == "transformers_llama_for_casual_lm":
+            task_type = "CAUSAL_LM"
+        if name == "transformers_llama_for_sequence_classification":
+            task_type = "SEQ_CLS"
+        lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+        err = run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config)
+
+        torch.cuda.empty_cache()
+
+        if err is None:
+            passed_models.append(name)
+        else:
+            failed_info[name] = err
+            if early_stop:
+                break
+
+    if dist.get_rank() == 0:
+        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
+        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
+    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
     check_low_level_zero_checkpointIO()
+    check_low_level_zero_lora_checkpointIO()
     torch.cuda.empty_cache()
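The checkpoint test above amounts to the following round trip, shown here as a condensed sketch that reuses the objects from the run_fn in the diff (booster/new_booster, model/new_model, the checkpoint paths, and check_state_dict_equal as imported from colossalai.testing in the test):

# Persist the LoRA adapter and the unsharded optimizer state.
booster.save_lora_as_pretrained(model, model_ckpt_path)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)

# Rebuild: load the saved adapter into a fresh copy of the base model through
# a second booster, then restore the optimizer and compare against the original.
new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)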