merge

applications/ColossalChat/.gitignore
@@ -151,6 +151,7 @@ examples/training_scripts/wandb
 examples/training_scripts/output

 examples/awesome-chatgpt-prompts/
 examples/inference/round.txt
+temp/

 # ColossalChat
@@ -121,7 +121,7 @@ cd $COLOSSAL_AI_ROOT
 BUILD_EXT=1 pip install .

 # Install ColossalChat
-cd $COLOSSAL_AI_ROOT/applications/Chat
+cd $COLOSSAL_AI_ROOT/applications/ColossalChat
 pip install .
 ```
@@ -49,6 +49,10 @@ def tokenize_sft(

     messages = data_point["messages"]
     template = deepcopy(conversation_template)

+    if messages[0]["from"] == "system":
+        template.system_message = str(messages[0]["content"])
+        messages.pop(0)
+
     template.messages = []
     for idx, mess in enumerate(messages):
         if mess["from"] != template.roles[idx % 2]:
@@ -148,11 +152,14 @@ def tokenize_prompt(
     template = deepcopy(conversation_template)
     template.messages = []

+    if messages[0]["from"] == "system":
+        template.system_message = str(messages[0]["content"])
+        messages.pop(0)
+
     for idx, mess in enumerate(messages):
         if mess["from"] != template.roles[idx % 2]:
             raise ValueError(
-                f"Message should iterate between user and assistant and starts with a \
-line from the user. Got the following data:\n{messages}"
+                f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
             )
         template.append_message(mess["from"], mess["content"])
@@ -162,7 +169,7 @@ def tokenize_prompt(
     template.messages = template.messages[:-1]

     # Prepare data
-    prompt = template.get_prompt(length=len(template.messages) - 1, add_generation_prompt=True)
+    prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
     tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]

     if tokenizer.bos_token_id is not None:
@@ -225,6 +232,10 @@ def tokenize_rlhf(
     template = deepcopy(conversation_template)
     template.clear()

+    if context[0]["from"] == "system":
+        template.system_message = str(context[0]["content"])
+        context.pop(0)
+
     for idx, mess in enumerate(context):
         if mess["from"] != template.roles[idx % 2]:
             raise ValueError(
@@ -345,6 +356,10 @@ def tokenize_kto(
     template = deepcopy(conversation_template)
     template.clear()

+    if prompt[0]["from"] == "system":
+        template.system_message = str(prompt[0]["content"])
+        prompt.pop(0)
+
     if prompt[0].get("from", None) != "user":
         raise ValueError("conversation should start with user")
    if completion.get("from", None) != "assistant":
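All four tokenization hunks add the same guard: a leading system turn is moved onto the conversation template before the user/assistant alternation check runs. Below is a minimal sketch of the pattern in isolation; the `Template` dataclass is a hypothetical stand-in for coati's conversation template, assumed only to expose `roles`, `system_message`, and `messages` as the hunks above do.

```python
from copy import deepcopy
from dataclasses import dataclass, field

@dataclass
class Template:  # hypothetical stand-in for coati's conversation template
    roles: tuple = ("user", "assistant")
    system_message: str = ""
    messages: list = field(default_factory=list)

def consume_system_message(messages, conversation_template):
    """The guard added in all four hunks: a leading system turn is moved onto
    the template so the user/assistant alternation check below still holds."""
    template = deepcopy(conversation_template)
    if messages[0]["from"] == "system":
        template.system_message = str(messages[0]["content"])
        messages.pop(0)
    for idx, mess in enumerate(messages):
        if mess["from"] != template.roles[idx % 2]:
            raise ValueError(f"conversation must alternate user/assistant: {messages}")
    return template, messages

data = [{"from": "system", "content": "You are helpful."},
        {"from": "user", "content": "hi"},
        {"from": "assistant", "content": "hello"}]
template, turns = consume_system_message(data, Template())
print(template.system_message, len(turns))  # "You are helpful." 2
```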
@@ -377,4 +392,4 @@ def tokenize_kto(
         "label": data_point["label"],
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
-    }
+    }
@@ -46,7 +46,10 @@ class PolicyLoss(nn.Module):
         action_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         skip = False
-        ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
+        if action_mask is None:
+            ratio_ = (log_probs - old_log_probs).exp()
+        else:
+            ratio_ = ((log_probs - old_log_probs) * action_mask).exp()

         # note that if dropout is disabled (recommended), ratio will always be 1.
         if ratio_.mean() > self.skip_threshold:
@@ -56,7 +59,10 @@ class PolicyLoss(nn.Module):
         surr1 = ratio * advantages
         surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
         loss = -torch.min(surr1, surr2)
-        loss = masked_mean(loss, action_mask)
+        if action_mask is not None:
+            loss = masked_mean(loss, action_mask)
+        else:
+            loss = loss.mean(dim=1)
         loss = loss.mean()
         return loss, skip, ratio_.max()
@@ -81,8 +87,10 @@ class ValueLoss(nn.Module):
         values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
         surr1 = (values_clipped - returns) ** 2
         surr2 = (values - returns) ** 2
-        loss = torch.max(surr1, surr2) / torch.sum(action_mask)
-        loss = torch.sum(loss * action_mask)
+        if action_mask is not None:
+            loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
+        else:
+            loss = torch.mean(torch.max(surr1, surr2))
         return 0.5 * loss
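Both loss hunks make `action_mask` genuinely optional: with a mask, the loss is averaged only over generated (action) tokens; without one, it falls back to a plain mean. A runnable sketch of the policy side, assuming `masked_mean` averages over positions where the mask is 1 (the helper's behavior is inferred here, not quoted from coati):

```python
import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # assumed semantics of coati's helper: mean over positions where mask == 1
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)

def clipped_policy_loss(log_probs, old_log_probs, advantages, action_mask=None, clip_eps=0.2):
    """Sketch of the PPO-clip objective with the optional mask added above."""
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    loss = -torch.min(surr1, surr2)
    # with a mask, average only over generated (action) tokens
    loss = masked_mean(loss, action_mask) if action_mask is not None else loss.mean(dim=1)
    return loss.mean()

B, T = 4, 16
lp, olp, adv = torch.randn(B, T), torch.randn(B, T), torch.randn(B, T)
print(clipped_policy_loss(lp, olp, adv, action_mask=torch.ones(B, T)))
```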
@@ -138,6 +138,7 @@ def disable_dropout(model: torch.nn.Module):
     Returns:
         None
     """
-    for module in model.modules():
-        if isinstance(module, torch.nn.Dropout):
-            module.p = 0.0
+    if model is not None:
+        for module in model.modules():
+            if isinstance(module, torch.nn.Dropout):
+                module.p = 0.0
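The `None` guard matters because callers may legitimately pass no model, for example when the reference model is disabled. A small sketch of the failure the guard prevents:

```python
import torch

def disable_dropout(model):
    # guarded version from the hunk above: tolerate model=None
    if model is not None:
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = 0.0

ref_model = None  # e.g. running DPO with --disable_reference_model
disable_dropout(ref_model)  # previously raised AttributeError on None.modules()
```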
@@ -56,6 +56,7 @@ class DPOTrainer(SLTrainer):
         beta: float = 0.1,
         gamma: float = 0.0,
         length_normalization: bool = False,
+        apply_loss_mask: bool = True,
         accumulation_steps: int = 1,
         start_epoch: int = 0,
         save_interval: int = 0,
@@ -67,6 +68,7 @@ class DPOTrainer(SLTrainer):
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.actor_loss_fn = DpoLoss(beta, gamma)
+        self.apply_loss_mask = apply_loss_mask
         self.save_interval = save_interval
         self.coordinator = coordinator
         self.save_dir = save_dir
@@ -135,6 +137,10 @@ class DPOTrainer(SLTrainer):
             batch["reject_attention_mask"],
             batch["reject_loss_mask"],
         )
+        if not self.apply_loss_mask:
+            chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+            reject_loss_mask = reject_loss_mask.fill_(1.0)
+
         batch_size = chosen_input_ids.size()[0]

         actor_all_logits = self.model(
@@ -284,6 +290,9 @@ class DPOTrainer(SLTrainer):
             batch["reject_attention_mask"],
             batch["reject_loss_mask"],
         )
+        if not self.apply_loss_mask:
+            chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+            reject_loss_mask = reject_loss_mask.fill_(1.0)

         batch_size = chosen_input_ids.size()[0]
@@ -347,4 +356,4 @@ class DPOTrainer(SLTrainer):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()
@@ -6,7 +6,7 @@ import os
 from typing import Any, Optional

 import torch
-import torch.distributed
+import torch.distributed as dist
 from coati.models.loss import KTOLoss
 from coati.models.utils import calc_masked_log_probs
 from coati.trainer.utils import all_reduce_mean
@@ -59,6 +59,7 @@ class KTOTrainer(SLTrainer):
         beta: float = 0.1,
         desirable_weight: float = 1.0,
         undesirable_weight: float = 1.0,
+        apply_loss_mask: bool = True,
         accumulation_steps: int = 1,
         start_epoch: int = 0,
         save_interval: int = 0,
@@ -70,6 +71,7 @@ class KTOTrainer(SLTrainer):
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
         self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
+        self.apply_loss_mask = apply_loss_mask
         self.save_interval = save_interval
         self.coordinator = coordinator
         self.save_dir = save_dir
@@ -134,6 +136,10 @@ class KTOTrainer(SLTrainer):
             batch["kl_attention_mask"],
             batch["kl_loss_mask"],
         )
+        if not self.apply_loss_mask:
+            loss_mask = loss_mask.fill_(1.0)
+            kl_loss_mask = kl_loss_mask.fill_(1.0)
+
         batch_size = input_ids.size()[0]

         # actor logits
@@ -182,8 +188,28 @@ class KTOTrainer(SLTrainer):

         # sync
         loss_mean = all_reduce_mean(tensor=loss)
-        chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
-        rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
+        chosen_reward_mean = chosen_rewards.mean()
+        chosen_rewards_list = [
+            torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(chosen_rewards_list, chosen_reward_mean)
+        rejected_reward_mean = rejected_rewards.mean()
+        rejected_rewards_list = [
+            torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(rejected_rewards_list, rejected_reward_mean)
+        chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
+        rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
+        chosen_rewards_mean = (
+            torch.stack(chosen_rewards_list).mean()
+            if len(chosen_rewards_list) > 0
+            else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+        )
+        rejected_rewards_mean = (
+            torch.stack(rejected_rewards_list).mean()
+            if len(rejected_rewards_list) > 0
+            else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
+        )
         self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
         self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
         self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
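This hunk swaps `all_reduce_mean` for gather-then-filter: a rank whose micro-batch contains only desirable (or only undesirable) samples produces a NaN mean reward, and an all-reduce would propagate that NaN to every rank. A condensed sketch of the same aggregation, assuming an already-initialized `torch.distributed` process group:

```python
import torch
import torch.distributed as dist

def nan_safe_gathered_mean(local_mean: torch.Tensor) -> torch.Tensor:
    """Gather each rank's scalar mean reward and average only the non-NaN
    entries, mirroring the aggregation in the hunk above. Assumes the
    process group is initialized and local_mean is a 0-dim tensor."""
    gathered = [torch.zeros_like(local_mean) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_mean)
    valid = [t for t in gathered if not t.isnan()]
    if not valid:  # no rank saw a sample of this class this step
        return torch.tensor(torch.nan, dtype=local_mean.dtype, device=local_mean.device)
    return torch.stack(valid).mean()
```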
@@ -256,6 +282,11 @@ class KTOTrainer(SLTrainer):
             batch["kl_attention_mask"],
             batch["kl_loss_mask"],
         )
+
+        if not self.apply_loss_mask:
+            loss_mask = loss_mask.fill_(1.0)
+            kl_loss_mask = kl_loss_mask.fill_(1.0)
+
         batch_size = input_ids.size()[0]

         # actor logits
@@ -315,4 +346,4 @@ class KTOTrainer(SLTrainer):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()
@@ -52,6 +52,7 @@ class ORPOTrainer(SLTrainer):
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
         lam: float = 0.1,
+        apply_loss_mask: bool = True,
         accumulation_steps: int = 1,
         start_epoch: int = 0,
         save_interval: int = 0,
@@ -67,6 +68,7 @@ class ORPOTrainer(SLTrainer):
         self.save_dir = save_dir
         self.num_train_step = 0
         self.lam = lam
+        self.apply_loss_mask = apply_loss_mask
         self.accumulation_steps = accumulation_steps
         self.device = get_current_device()
         self.accumulative_meter = AccumulativeMeanMeter()
@@ -130,6 +132,11 @@ class ORPOTrainer(SLTrainer):
             batch["reject_attention_mask"],
             batch["reject_loss_mask"],
         )
+
+        if not self.apply_loss_mask:
+            chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+            reject_loss_mask = reject_loss_mask.fill_(1.0)
+
         batch_size = chosen_input_ids.size()[0]
         actor_out = self.model(
             input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@@ -263,6 +270,11 @@ class ORPOTrainer(SLTrainer):
             batch["reject_attention_mask"],
             batch["reject_loss_mask"],
         )
+
+        if not self.apply_loss_mask:
+            chosen_loss_mask = chosen_loss_mask.fill_(1.0)
+            reject_loss_mask = reject_loss_mask.fill_(1.0)
+
         batch_size = chosen_input_ids.size()[0]
         actor_out = self.model(
             input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@@ -311,4 +323,4 @@ class ORPOTrainer(SLTrainer):
             os.makedirs(self.save_dir, exist_ok=True)
             with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
                 f.write(msg)
-        step_bar.close()
+        step_bar.close()
@@ -102,6 +102,7 @@ class PPOTrainer(OLTrainer):
         sample_buffer: bool = False,
         dataloader_pin_memory: bool = True,
         offload_inference_models: bool = True,
+        apply_loss_mask: bool = True,
         accumulation_steps: int = 1,
         save_interval: int = 0,
         save_dir: str = None,
@@ -140,6 +141,7 @@ class PPOTrainer(OLTrainer):
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim
         self.save_interval = save_interval
+        self.apply_loss_mask = apply_loss_mask
         self.coordinator = coordinator
         self.actor_save_dir = os.path.join(save_dir, "actor")
         self.critic_save_dir = os.path.join(save_dir, "critic")
@@ -229,7 +231,10 @@ class PPOTrainer(OLTrainer):
             action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)

             actor_loss, to_skip, max_ratio = self.actor_loss_fn(
-                action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+                action_log_probs,
+                experience.action_log_probs,
+                experience.advantages,
+                action_mask=experience.action_mask if self.apply_loss_mask else None,
             )
             actor_loss = (1 - self.ptx_coef) * actor_loss
             if not to_skip:
@@ -249,7 +254,10 @@ class PPOTrainer(OLTrainer):
                 input_ids=experience.sequences, attention_mask=experience.attention_mask
             )  # [batch size, prompt_length + response_length]
             critic_loss = self.critic_loss_fn(
-                values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
+                values[:, -num_actions:],
+                experience.values,
+                experience.advantages,
+                action_mask=experience.action_mask if self.apply_loss_mask else None,
             )
             critic_loss = critic_loss * self.vf_coef
             self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
@@ -41,6 +41,7 @@ class SFTTrainer(SLTrainer):
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
         accumulation_steps: int = 8,
+        apply_loss_mask: bool = True,
         start_epoch=0,
         save_interval: int = None,
         save_dir: str = None,
@@ -55,6 +56,7 @@ class SFTTrainer(SLTrainer):
         self.coordinator = coordinator
         self.num_train_step = 0
         self.num_eval_step = 0
+        self.apply_loss_mask = apply_loss_mask
         self.accumulative_meter = AccumulativeMeanMeter()

     def _before_fit(
@@ -100,7 +102,11 @@ class SFTTrainer(SLTrainer):
         for i, batch in enumerate(self.train_dataloader):
             batch = to_device(batch, torch.cuda.current_device())
             batch_size = batch["input_ids"].size(0)
-            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+            outputs = self.model(
+                batch["input_ids"],
+                attention_mask=batch["attention_mask"],
+                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+            )
             loss = outputs.loss

             self.booster.backward(loss=loss, optimizer=self.optimizer)
@@ -158,7 +164,11 @@ class SFTTrainer(SLTrainer):
             )
             for batch in self.eval_dataloader:
                 batch = to_device(batch, torch.cuda.current_device())
-                outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+                outputs = self.model(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                )
                 loss_mean = all_reduce_mean(tensor=outputs.loss)
                 self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
                 step_bar.update()
@@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
 - save_dir: path to store the model checkpoints.
 - max_length: input will be padded/truncated to max_length before feeding to the model.
 - max_epochs: number of epochs to train.
+- disable_loss_mask: whether to disable the loss mask. For example, in SFT, if the loss mask is disabled, the model computes the loss across all tokens in the sequence; if the loss mask is applied, only tokens corresponding to the assistant responses contribute to the final loss (see the sketch after this list).
 - batch_size: training batch size.
 - mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
 - save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
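A minimal sketch of what the loss mask toggles, assuming HuggingFace-style labels where positions set to -100 are ignored by the cross-entropy loss:

```python
import torch
import torch.nn.functional as F

# Sketch of what disable_loss_mask controls in SFT: with the mask, prompt
# positions are set to -100 and excluded from the loss; without it, every
# token in the sequence contributes.
logits = torch.randn(1, 6, 32000)            # [batch, seq_len, vocab]
input_ids = torch.randint(0, 32000, (1, 6))  # prompt (3 tokens) + response (3 tokens)

masked_labels = input_ids.clone()
masked_labels[:, :3] = -100                  # mask the prompt: loss on response only

for labels in (masked_labels, input_ids):    # loss mask applied vs. disabled
    loss = F.cross_entropy(logits.view(-1, 32000), labels.view(-1), ignore_index=-100)
    print(loss.item())
```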
@@ -461,26 +462,24 @@ Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of

 #### Step 1: Data Collection
-The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
+The first step in Stage 1 is to collect a dataset of human demonstrations of the following JSONL format.

 ```json
-[
-    {"messages":
-      [
-        {
-          "from": "user",
-          "content": "what are some pranks with a pen i can do?"
-        },
-        {
-          "from": "assistant",
-          "content": "Are you looking for practical joke ideas?"
-        },
-        ...
-      ]
-    },
-    ...
-]
+{"messages":
+  [
+    {
+      "from": "user",
+      "content": "what are some pranks with a pen i can do?"
+    },
+    {
+      "from": "assistant",
+      "content": "Are you looking for practical joke ideas?"
+    },
+    ...
+  ]
+},
+...
 ```
@@ -904,4 +903,4 @@ For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/mai
 ## Attention

-The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
+The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
@@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs
         tuple: A tuple containing the loaded model and tokenizer.
     """

-    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
     tokenizer.pad_token = tokenizer.eos_token
     model.to(device)
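A self-contained sketch of the updated loader, mirroring the hunk above; the checkpoint path below is a placeholder, not a path from this repository:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs):
    # cast to bfloat16 and trust remote code, so custom architectures
    # published on the Hub load without extra flags
    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding
    model.to(device)
    return model, tokenizer

# placeholder path for illustration only
model, tokenizer = load_model_and_tokenizer("path/to/sft-checkpoint", "path/to/sft-checkpoint")
```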
@@ -151,7 +151,6 @@ def main(args):
         chat_io.prompt_for_output("assistant")

         prompt = conv.get_prompt(add_generation_prompt=True)
-        print(prompt + "<end_of_prompt>")
         input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
             torch.cuda.current_device()
         )
@@ -278,6 +278,7 @@ def train(args):
         beta=args.beta,
         gamma=args.gamma,
         length_normalization=args.length_normalization,
+        apply_loss_mask=not args.disable_loss_mask,
     )

     trainer.fit(
@@ -346,6 +350,7 @@ if __name__ == "__main__":
         default=False,
         help="Disable the reference model (enabled by default)",
     )
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
     parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
@@ -297,6 +297,7 @@ def train(args):
         beta=args.beta,
         desirable_weight=args.desirable_weight,
         undesirable_weight=args.undesirable_weight,
+        apply_loss_mask=not args.disable_loss_mask,
     )

     trainer.fit(
@@ -341,6 +342,7 @@ if __name__ == "__main__":
     parser.add_argument("--beta", type=float, default=0.1, help="beta in KTO loss")
     parser.add_argument("--desirable_weight", type=float, default=1.0, help="desirable_weight in KTO loss")
     parser.add_argument("--undesirable_weight", type=float, default=1.0, help="undesirable_weight in KTO loss")
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
     parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
     parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
@@ -373,4 +375,4 @@ if __name__ == "__main__":
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
         with open(args.config_file, "w") as f:
             json.dump(args.__dict__, f, indent=4)
-    train(args)
+    train(args)
@@ -259,6 +259,7 @@ def train(args):
         save_dir=args.save_dir,
         coordinator=coordinator,
         lam=args.lam,
+        apply_loss_mask=not args.disable_loss_mask,
     )

     trainer.fit(
@@ -301,6 +302,7 @@ if __name__ == "__main__":
     parser.add_argument("--pp", type=int, default=1)
     parser.add_argument("--sp", type=int, default=1)
     parser.add_argument("--lam", type=float, default=0.1, help="lambda in ORPO loss")
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
     parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
     parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
@@ -338,4 +340,4 @@ if __name__ == "__main__":
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
         with open(args.config_file, "w") as f:
             json.dump(args.__dict__, f, indent=4)
-    train(args)
+    train(args)
@@ -411,6 +411,7 @@ def train(args):
         use_cache=True,
         do_sample=True,
         temperature=0.7,
+        apply_loss_mask=not args.disable_loss_mask,
         accumulation_steps=args.accumulation_steps,
         save_dir=args.save_path,
         save_interval=args.save_interval,
@@ -498,9 +499,10 @@ if __name__ == "__main__":
     parser.add_argument("--critic_lr", type=float, default=9e-6)
     parser.add_argument("--kl_coef", type=float, default=0.1)
     parser.add_argument("--ptx_coef", type=float, default=0.0)
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--max_length", type=int, default=2048)
     parser.add_argument("--max_seq_len", type=int, default=256)
-    parser.add_argument("--log_dir", default="logs", type=str)
+    parser.add_argument("--log_dir", default=None, type=str)
     parser.add_argument("--use_wandb", default=False, action="store_true")
     parser.add_argument("--grad_checkpoint", default=False, action="store_true")
     parser.add_argument("--use_flash_attn", default=False, action="store_true")
@@ -272,6 +272,7 @@ def train(args):
         lr_scheduler=lr_scheduler,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps,
+        apply_loss_mask=not args.disable_loss_mask,
         start_epoch=start_epoch,
         save_interval=args.save_interval,
         save_dir=args.save_path,
@@ -317,6 +318,7 @@ if __name__ == "__main__":
     parser.add_argument("--tp", type=int, default=1)
     parser.add_argument("--pp", type=int, default=1)
     parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
     parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
     parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
     parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
@@ -2,7 +2,7 @@ transformers==4.39.3
 tqdm
 datasets==2.14.7
 loralib
-colossalai==0.4.0
+colossalai>=0.4.0
 torch>=2.1.0
 langchain
 tokenizers
@@ -20,4 +20,4 @@ datasets
 ninja==1.11.1
 sentencepiece==0.1.99
 flash-attn
-tiktoken
+tiktoken
@@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }

-set_n_least_used_CUDA_VISIBLE_DEVICES 4
+set_n_least_used_CUDA_VISIBLE_DEVICES 2

 set -xu
@@ -119,11 +119,11 @@ for lora_rank in ${LORA_RANK[@]}; do
         lora_config=""
     fi
     if [[ $plugin == "3d" ]]; then
-        tp='4'
+        tp='2'
         bs='8'
     fi
     if [[ $plugin == "tp_zero2" ]]; then
-        tp='4'
+        tp='2'
         bs='8'
         zero_stage='2'
         plugin='3d'
@@ -136,13 +136,13 @@ for lora_rank in ${LORA_RANK[@]}; do
     fi
     if [[ $plugin == "pp" ]]; then
         bs='8'
-        pp='4'
+        pp='2'
         plugin='3d'
     fi
     if [[ $plugin == "sp_split_gather" ]]; then
         enable_sequence_parallelism='--enable_sequence_parallelism'
         sp_mode='split_gather'
-        tp='4'
+        tp='2'
         sp='1'
         bs='8'
         plugin='3d'
@@ -150,7 +150,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     if [[ $plugin == "sp_ring" ]]; then
         enable_sequence_parallelism='--enable_sequence_parallelism'
         sp_mode='ring'
-        tp='4'
+        tp='2'
         sp='1'
         bs='8'
         plugin='3d'
@@ -159,7 +159,7 @@ for lora_rank in ${LORA_RANK[@]}; do
         enable_sequence_parallelism='--enable_sequence_parallelism'
         sp_mode='all_to_all'
         tp='1'
-        sp='4'
+        sp='2'
         bs='8'
         plugin='3d'
     fi
@@ -175,7 +175,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     for split in $(seq -f "%05g" 0 0); do
         dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
     done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
         --pretrain $pretrain \
         --tokenizer_dir $tokenizer_dir \
         --dataset ${dataset[@]} \
@@ -242,7 +242,7 @@ for lora_rank in ${LORA_RANK[@]}; do
         lora_config=""
     fi
     if [[ $plugin == "3d" ]]; then
-        tp='4'
+        tp='2'
         bs='8'
     fi
     grad_accu='2'
@@ -256,7 +256,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     for split in $(seq -f "%05g" 0 0); do
         dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
     done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_rm.py \
         --pretrain $pretrain \
         --tokenizer_dir $tokenizer_dir \
         --dataset ${dataset[@]} \
@@ -325,7 +325,7 @@ for lora_rank in ${LORA_RANK[@]}; do
         lora_config=""
     fi
     if [[ $plugin == "3d" ]]; then
-        tp='4'
+        tp='2'
         bs='16'
         ebs='32'
     fi
@@ -350,7 +350,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     for split in $(seq -f "%05g" 0 0); do
         ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
     done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
         --pretrain $pretrain \
         --rm_pretrain $pretrain \
         --tokenizer_dir $tokenizer_dir \
@@ -417,7 +417,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     tp='1'
     bs='2'
     if [[ $plugin == "3d" ]]; then
-        tp='4'
+        tp='2'
         bs='8'
     fi
     if [[ $plugin == "zero2" ]]; then
@@ -442,7 +442,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     for split in $(seq -f "%05g" 0 0); do
         dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
     done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_dpo.py \
         --pretrain $pretrain \
         --tokenizer_dir $tokenizer_dir \
         --dataset ${dataset[@]} \
@@ -500,7 +500,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     tp='1'
     bs='2'
     if [[ $plugin == "3d" ]]; then
-        tp='4'
+        tp='2'
         bs='8'
     fi
     if [[ $plugin == "zero2" ]]; then
@@ -525,7 +525,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     for split in $(seq -f "%05g" 0 0); do
         dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
     done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
         --pretrain $pretrain \
         --tokenizer_dir $tokenizer_dir \
         --dataset ${dataset[@]} \
@@ -583,7 +583,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     tp='1'
     bs='2'
     if [[ $plugin == "3d" ]]; then
-        tp='4'
+        tp='2'
         bs='8'
     fi
     if [[ $plugin == "zero2" ]]; then
@@ -608,7 +608,7 @@ for lora_rank in ${LORA_RANK[@]}; do
     for split in $(seq -f "%05g" 0 0); do
         dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_kto/arrow/part-$split")
     done
-    colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
+    colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_kto.py \
         --pretrain $pretrain \
         --tokenizer_dir $tokenizer_dir \
         --dataset ${dataset[@]} \
@@ -640,4 +640,4 @@ for lora_rank in ${LORA_RANK[@]}; do
 fi
 done
 done
-done
+done
@@ -14,9 +14,9 @@ This directory contains the applications that are powered by Colossal-AI.
 The list of applications includes:

 - [X] [Open-Sora](https://github.com/hpcaitech/Open-Sora): Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models
+- [X] [ColossalChat](./ColossalChat/): Replication of ChatGPT with RLHF.
 - [X] [Colossal-LLaMA](./Colossal-LLaMA/): Continual Pre-training and Supervised Fine-tuning of LLaMA2 / LLaMA3.
 - [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
-- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
 - [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
 - [X] [ColossalQA](./ColossalQA/README.md): Document Retrieval Conversation System
 - [X] [SwiftInfer](https://github.com/hpcaitech/SwiftInfer): Breaks the Length Limit of LLM Inference for Multi-Round Conversations