from typing import Any, Callable, Dict, List, Optional
from collections import OrderedDict
from dataclasses import dataclass

import torch
import torch.nn as nn
from loralib.layers import LoRALayer
from coati.models.lora import LoraLinear


@dataclass
class LoRAConfig:
    r: int = 0
    lora_alpha: int = 1
    lora_dropout: float = 0
    fan_in_fan_out: bool = False


class LoRAConstructor:
    '''
    Tools for reconstructing a model from a remote LoRA model.
    (Transferring only the LoRA data costs much less!)
    Usage:
        Step 1 (Sender):
            filter_state_dict_lora()

        Step 2 (Sender, Optional):
            extract_lora_config()

        Step 3 (Sender):
            send state_dict_lora and lora_config_dict

        Step 4 (Receiver):
            reconstruct_increase()

        Step 5 (Receiver):
            load_state_dict_increase()

    '''

    def __init__(self):
        self.lora_config_dict = None

    def register_lora_config(self, lora_config_dict: Dict[str, Any]):
        self.lora_config_dict = lora_config_dict
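    # Expects state_dict_lora to list each layer's lora_A immediately followed by its
    # lora_B, with lora_config_dict ordered the same way; the asserts below guard both.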
    def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
        '''
        xxx.lora_A, xxx.lora_B -->> xxx.weight
        Warning: the xxx.weight here is actually the increment.
        '''
        if lora_config_dict is not None:
            self.register_lora_config(lora_config_dict)

        state_dict_increase = OrderedDict()
        config_iter = iter(self.lora_config_dict.items())
        lora_A, lora_B, layer_prefix = None, None, None
        for k, v in state_dict_lora.items():
            if k.rpartition('.')[-1] == 'lora_A':
                lora_A = v
                layer_prefix = k.rpartition('.')[0]
            elif k.rpartition('.')[-1] == 'lora_B':
                assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
                layer_prefix_2, config = next(config_iter)
                assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
                lora_B = v
                weight_data_increase = self._compute(lora_A, lora_B, config)
                state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
                lora_A, lora_B, layer_prefix = None, None, None
            else:
                raise ValueError('unexpected key')
        return state_dict_increase
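    # The increment is the standard LoRA reconstruction:
    #   delta_W = (lora_B @ lora_A) * (lora_alpha / r)
    # transposed when fan_in_fan_out is set.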
    def _compute(self, lora_A, lora_B, config=LoRAConfig()):

        def T(w):
            return w.T if config.fan_in_fan_out else w

        if config.r > 0:
            scaling = config.lora_alpha / config.r
            weight_data_increase = T(lora_B @ lora_A) * scaling
            return weight_data_increase
        return 0

    def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
        '''
        The final reconstruction step
        '''
        # naive approach
        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
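    # LoRA checkpoints keep the low-rank factors under keys ending in '.lora_A' /
    # '.lora_B'; splitting them out means only the small LoRA tensors need to be sent.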
    @staticmethod
    def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
        '''
        If keep_non_lora is True, also return the non-LoRA state_dict.
        '''
        state_dict_lora = OrderedDict()
        state_dict_non_lora = OrderedDict()
        for k, v in state_dict.items():
            if 'lora_A' in k or 'lora_B' in k:
                state_dict_lora[k] = v
            elif keep_non_lora:
                state_dict_non_lora[k] = v
        if keep_non_lora:
            return state_dict_lora, state_dict_non_lora
        else:
            return state_dict_lora, None
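    # Record each LoraLinear's hyperparameters by module name so the receiver can
    # rebuild the weight increments without holding the modules themselves.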
    @staticmethod
    def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
        '''
        Extract the LoRA config of every LoraLinear module.
        return OrderedDict(): name -> LoRAConfig
        '''
        lora_config_dict = OrderedDict()

        for name, child in model.named_modules():
            if isinstance(child, LoraLinear):
                lora_config_dict[name] = LoRAConfig(r=child.r,
                                                    lora_alpha=child.lora_alpha,
                                                    lora_dropout=child.lora_dropout,
                                                    fan_in_fan_out=child.fan_in_fan_out)

        return lora_config_dict
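# A minimal usage sketch of the receiver half (Steps 4-5) on a toy module, assuming
# the loralib shape convention used by _compute: lora_A is (r, in_features) and
# lora_B is (out_features, r).
if __name__ == '__main__':
    in_features, out_features, r = 8, 4, 2

    # Stand-ins for what a sender would produce with filter_state_dict_lora()
    # and extract_lora_config(); Steps 1-3 (filtering and sending) are not shown.
    state_dict_lora = OrderedDict([
        ('layer.lora_A', torch.randn(r, in_features)),
        ('layer.lora_B', torch.randn(out_features, r)),
    ])
    lora_config_dict = OrderedDict([('layer', LoRAConfig(r=r, lora_alpha=16))])

    constructor = LoRAConstructor()
    # Step 4: turn each (lora_A, lora_B) pair into a dense weight increment.
    state_dict_increase = constructor.reconstruct_increase(state_dict_lora, lora_config_dict)

    # Step 5: add the increments onto a model that has a matching 'layer.weight'.
    model = nn.Sequential(OrderedDict([('layer', nn.Linear(in_features, out_features))]))
    constructor.load_state_dict_increase(model, state_dict_increase)
    print(state_dict_increase['layer.weight'].shape)    # torch.Size([4, 8])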