mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 19:36:13 +00:00
[feat] GRPO with distributed implementation (#6230)
* add reward related function * add simple grpo * update grpo * polish * modify data loader * grpo consumer * update loss * update reward fn * update example * update loader * add algo selection * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add save * update select algo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update grpo * update reward fn * update reward * fix reward score * add response length * detach * fix tp bug * fix consumer * convert to 8 generation * print results * setup update * fix transformers backend * [Feature] Support Distributed LogProb for GRPO Training (#6247) * [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * fix vllm * fix logprob, add filtering, temperature annealing, lr descent * simplify vllm preprocessing input ids * update logging * [feat] add microbatch forwarding (#6251) * add microbatch forwarding * fix forward microbatch * fix producer OOM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change project name * fix temperature annealing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address conversation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Distributed RLHF] Integration of PP (#6257) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * [hot-fix] Fix memory leakage bug, support TP+PP (#6258) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
This commit is contained in:
parent
2bb71c6248
commit
7bb7e80476
2
.gitignore
vendored
2
.gitignore
vendored
@ -163,3 +163,5 @@ coverage.xml
|
||||
# log, test files - ColossalChat
|
||||
applications/ColossalChat/logs
|
||||
applications/ColossalChat/tests/logs
|
||||
applications/ColossalChat/wandb
|
||||
applications/ColossalChat/model
|
||||
|
@ -356,10 +356,24 @@ def apply_chat_template_and_mask(
|
||||
truncation: bool = True,
|
||||
ignore_idx: int = -100,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
|
||||
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
|
||||
|
||||
system_element = {
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
|
||||
# Format for RL.
|
||||
gt_answer = None
|
||||
if "messages" in chat and "gt_answer" in chat:
|
||||
gt_answer = chat["gt_answer"]
|
||||
chat = [chat["messages"]]
|
||||
|
||||
tokens = []
|
||||
assistant_mask = []
|
||||
for i, msg in enumerate(chat):
|
||||
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
|
||||
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
|
||||
# remove unexpected bos token
|
||||
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
|
||||
msg_tokens = msg_tokens[1:]
|
||||
@ -372,14 +386,10 @@ def apply_chat_template_and_mask(
|
||||
if max_length is not None:
|
||||
if padding and len(tokens) < max_length:
|
||||
to_pad = max_length - len(tokens)
|
||||
if tokenizer.padding_side == "right":
|
||||
tokens.extend([tokenizer.pad_token_id] * to_pad)
|
||||
assistant_mask.extend([False] * to_pad)
|
||||
attention_mask.extend([0] * to_pad)
|
||||
else:
|
||||
tokens = [tokenizer.pad_token_id] * to_pad + tokens
|
||||
assistant_mask = [False] * to_pad + assistant_mask
|
||||
attention_mask = [0] * to_pad + attention_mask
|
||||
# Left padding for generation.
|
||||
tokens = [tokenizer.pad_token_id] * to_pad + tokens
|
||||
assistant_mask = [False] * to_pad + assistant_mask
|
||||
attention_mask = [0] * to_pad + attention_mask
|
||||
if truncation and len(tokens) > max_length:
|
||||
tokens = tokens[:max_length]
|
||||
assistant_mask = assistant_mask[:max_length]
|
||||
@ -389,6 +399,13 @@ def apply_chat_template_and_mask(
|
||||
labels = input_ids.clone()
|
||||
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
|
||||
|
||||
if gt_answer is not None:
|
||||
gt_answer = tokenizer.encode(
|
||||
gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
|
||||
)
|
||||
gt_answer = gt_answer.squeeze(1)
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@ -33,6 +34,8 @@ class BaseConsumer:
|
||||
model_config: Dict[str, Any],
|
||||
plugin_config: Dict[str, Any],
|
||||
microbatch_size: int = 1,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model",
|
||||
):
|
||||
self.num_producers = num_producers
|
||||
self.num_episodes = num_episodes
|
||||
@ -44,14 +47,16 @@ class BaseConsumer:
|
||||
self.num_recv_per_update = num_recv_per_update
|
||||
self.batch_size = batch_size
|
||||
self.microbatch_size = microbatch_size
|
||||
self.save_interval = save_interval
|
||||
self.save_dir = save_dir
|
||||
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // microbatch_size
|
||||
|
||||
self.model_config = model_config
|
||||
self.plugin_config = plugin_config
|
||||
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
|
||||
|
||||
self.device = get_current_device()
|
||||
self.lr_scheduler = None
|
||||
|
||||
def setup(self) -> None:
|
||||
for i in range(self.num_producers):
|
||||
@ -60,18 +65,15 @@ class BaseConsumer:
|
||||
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
|
||||
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
|
||||
|
||||
plugin_config = dict(
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
precision="bf16",
|
||||
zero_stage=1,
|
||||
)
|
||||
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
|
||||
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
|
||||
plugin_config["microbatch_size"] = self.microbatch_size
|
||||
plugin_config.update(self.plugin_config)
|
||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||
self.booster = Booster(plugin=self.plugin)
|
||||
self.dp_rank = dist.get_rank(self.plugin.dp_group)
|
||||
self.tp_rank = dist.get_rank(self.plugin.tp_group)
|
||||
|
||||
self.dp_size = dist.get_world_size(self.plugin.dp_group)
|
||||
|
||||
self.buffer = []
|
||||
@ -94,7 +96,6 @@ class BaseConsumer:
|
||||
i = 0
|
||||
for _ in range(self.num_recv_per_update):
|
||||
# receive data from producers
|
||||
|
||||
for r in range(self.num_producers):
|
||||
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
||||
self.buffer.extend(
|
||||
@ -116,13 +117,26 @@ class BaseConsumer:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
assert len(self.buffer) == 0
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
if (step + 1) % self.save_interval == 0:
|
||||
if self.rank == 0:
|
||||
print(f"Start saving policy model at step {step + 1}.")
|
||||
save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
|
||||
self.booster.save_model(self.policy_model, save_path, shard=True)
|
||||
if self.rank == 0:
|
||||
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
|
||||
|
||||
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.state_dict()
|
||||
if self.rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
529
applications/ColossalChat/coati/distributed/grpo_consumer.py
Normal file
529
applications/ColossalChat/coati/distributed/grpo_consumer.py
Normal file
@ -0,0 +1,529 @@
|
||||
import json
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from coati.distributed.consumer import BaseConsumer
|
||||
from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.reward.reward_fn import math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import calc_action_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GRPOConsumer(BaseConsumer):
|
||||
def __init__(
|
||||
self,
|
||||
num_producers,
|
||||
num_episodes,
|
||||
rank,
|
||||
world_size,
|
||||
master_addr,
|
||||
master_port,
|
||||
num_update_per_episode,
|
||||
num_recv_per_update,
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size=1,
|
||||
num_generations=8,
|
||||
use_wandb=True,
|
||||
generate_config=None,
|
||||
training_config={},
|
||||
project_name=None,
|
||||
):
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
rank,
|
||||
world_size,
|
||||
master_addr,
|
||||
master_port,
|
||||
num_update_per_episode,
|
||||
num_recv_per_update,
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.policy_model.train()
|
||||
self.policy_model.gradient_checkpointing_enable()
|
||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
|
||||
self.accum_loss = torch.zeros(1, device=self.device)
|
||||
self.accum_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_kl = torch.zeros(1, device=self.device)
|
||||
self.accum_format_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_acc_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_advantages = torch.zeros(1, device=self.device)
|
||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
||||
self.accum_count = 0
|
||||
self.generate_config = generate_config
|
||||
self.training_config = training_config
|
||||
self.project_name = project_name
|
||||
|
||||
# Reference model is initialized from policy model.
|
||||
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.reference_model.eval()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.num_generations = num_generations
|
||||
self.filter_range = training_config.get("filter_range", None)
|
||||
if self.filter_range is not None:
|
||||
assert len(self.filter_range) == 2, "Filter range should have 2 values."
|
||||
|
||||
# Initialize verifiable reward.
|
||||
response_format_tags = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
self.reward_model = VerifiableReward(
|
||||
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
|
||||
)
|
||||
|
||||
self.policy_loss_fn = PolicyLoss()
|
||||
self.global_step = 0
|
||||
self.use_wandb = use_wandb
|
||||
|
||||
self.lr_scheduler = CosineAnnealingWarmupLR(
|
||||
optimizer=self.optimizer,
|
||||
total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
|
||||
warmup_steps=0,
|
||||
eta_min=0.1 * training_config.get("lr", 1e-6),
|
||||
)
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
if self.use_wandb and (
|
||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
|
||||
):
|
||||
# Initialize wandb.
|
||||
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
|
||||
self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
|
||||
|
||||
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
||||
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
||||
)
|
||||
self.reference_model, *_ = self.booster.boost(self.reference_model)
|
||||
self.plugin.logger.set_level("ERROR")
|
||||
|
||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||
"""
|
||||
Step data from policy model:
|
||||
[{
|
||||
"input_ids": torch.Tensor,
|
||||
"attention_mask": torch.Tensor,
|
||||
"action_mask": torch.Tensor,
|
||||
"action_log_probs": torch.Tensor,
|
||||
},
|
||||
...]
|
||||
Format:
|
||||
[batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
||||
"""
|
||||
|
||||
# Reshape to [batch_size x num_of_generation, prompt_length + response_length]
|
||||
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
|
||||
action_mask = data["action_mask"]
|
||||
num_action = action_mask.shape[1]
|
||||
old_action_log_probs = data["action_log_probs"]
|
||||
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
|
||||
forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
|
||||
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
|
||||
ctx = (
|
||||
nullcontext()
|
||||
if need_update or self.booster.plugin.zero_stage == 2
|
||||
else self.booster.no_sync(self.policy_model, self.optimizer)
|
||||
)
|
||||
with ctx:
|
||||
reward_group = self.reward_model(
|
||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
|
||||
)
|
||||
|
||||
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
|
||||
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
|
||||
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
|
||||
|
||||
# [batch_size, num_generations]
|
||||
|
||||
group_reward = reward.view(-1, self.num_generations)
|
||||
reward_mean = group_reward.mean(dim=1)
|
||||
# [batch_size x num_generations]
|
||||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
|
||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
# [batch_size x num_generations]
|
||||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
||||
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
|
||||
loss_mask = (
|
||||
None
|
||||
if self.filter_range is None
|
||||
else torch.logical_and(
|
||||
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
|
||||
).repeat_interleave(self.num_generations, dim=0)
|
||||
)
|
||||
mean_kl, mean_loss = [], []
|
||||
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
attention_mask_forward_micro_batch = data["attention_mask"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
action_mask_forward_micro_batch = action_mask[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
loss_mask_forward_micro_batch = (
|
||||
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
|
||||
if loss_mask is not None
|
||||
else None
|
||||
)
|
||||
advantages_forward_micro_batch = advantages[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
]
|
||||
|
||||
if self.plugin.pp_size > 1:
|
||||
# Support training with PP.
|
||||
|
||||
with torch.no_grad():
|
||||
reference_model_outputs = self.booster.execute_pipeline(
|
||||
iter(
|
||||
[
|
||||
{
|
||||
"input_ids": input_ids_forward_micro_batch,
|
||||
"attention_mask": attention_mask_forward_micro_batch,
|
||||
}
|
||||
]
|
||||
),
|
||||
self.reference_model,
|
||||
criterion=lambda outputs, inputs: torch.tensor(
|
||||
[0.0], device=action_mask.device
|
||||
), # dummy criterion
|
||||
optimizer=None,
|
||||
return_loss=False,
|
||||
return_outputs=True,
|
||||
)
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
reference_model_logits = reference_model_outputs["outputs"]["logits"]
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
else:
|
||||
# Dummy reference logprobs for data iterator.
|
||||
reference_action_log_probs = None
|
||||
|
||||
data_policy_forward = {
|
||||
"input_ids": input_ids_forward_micro_batch,
|
||||
"attention_mask": attention_mask_forward_micro_batch,
|
||||
"action_mask": action_mask_forward_micro_batch,
|
||||
"reference_action_log_probs": reference_action_log_probs,
|
||||
"advantages": advantages_forward_micro_batch,
|
||||
"loss_mask": loss_mask_forward_micro_batch,
|
||||
"source": self.rank,
|
||||
}
|
||||
|
||||
kl = []
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
action_logits = outputs.logits
|
||||
action_log_probs = calc_action_log_probs(
|
||||
action_logits / self.generate_config["temperature"],
|
||||
inputs["input_ids"],
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
|
||||
- (inputs["reference_action_log_probs"] - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
|
||||
inputs["action_mask"], dim=-1
|
||||
)
|
||||
kl.append(appox_kl.mean())
|
||||
loss, skip_update, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
action_log_probs,
|
||||
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
inputs["action_mask"],
|
||||
loss_mask=inputs["loss_mask"],
|
||||
)
|
||||
return loss
|
||||
|
||||
policy_model_outputs = self.booster.execute_pipeline(
|
||||
iter([data_policy_forward]),
|
||||
self.policy_model,
|
||||
criterion=_criterion,
|
||||
optimizer=self.optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True,
|
||||
)
|
||||
loss = policy_model_outputs["loss"]
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
if len(kl) > 0:
|
||||
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
|
||||
mean_kl.append(kl)
|
||||
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
|
||||
else:
|
||||
|
||||
policy_model_logits = self.policy_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
action_log_probs = calc_action_log_probs(
|
||||
policy_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
- (reference_action_log_probs - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
||||
action_mask_forward_micro_batch, dim=-1
|
||||
)
|
||||
|
||||
loss, skip_update, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
loss_mask=loss_mask_forward_micro_batch,
|
||||
)
|
||||
|
||||
if not skip_update:
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
loss = all_reduce_mean(loss, self.plugin)
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
# Calculate accumulate value.
|
||||
mean_kl.append(kl.data)
|
||||
mean_loss.append(loss.data)
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
reward = all_reduce_mean(reward.mean(), self.plugin)
|
||||
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
|
||||
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
|
||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||
self.accum_reward.add_(reward.data)
|
||||
self.accum_format_reward.add_(format_reward.data)
|
||||
self.accum_acc_reward.add_(acc_reward.data)
|
||||
self.accum_advantages.add_(advantages.data)
|
||||
self.accum_response_length.add_(response_length.data)
|
||||
self.accum_count += 1
|
||||
if need_update:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
loss_scalar = self.accum_loss.item()
|
||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
print(
|
||||
"Loss:",
|
||||
self.accum_loss.item() / self.accum_count,
|
||||
"\nReward:",
|
||||
self.accum_reward.item() / self.accum_count,
|
||||
"\nFormat Reward:",
|
||||
self.accum_format_reward.item() / self.accum_count,
|
||||
"\nAcc Reward:",
|
||||
self.accum_acc_reward.item() / self.accum_count,
|
||||
"\nKL:",
|
||||
self.accum_kl.item() / self.accum_count,
|
||||
"\nAdvantages:",
|
||||
self.accum_advantages.item() / self.accum_count,
|
||||
"\nResponse Length:",
|
||||
self.accum_response_length.item() / self.accum_count,
|
||||
)
|
||||
self.wandb_run.log(
|
||||
{
|
||||
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
||||
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
|
||||
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
|
||||
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
|
||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||
"train/kl": self.accum_kl.item() / self.accum_count,
|
||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||
}
|
||||
)
|
||||
self.accum_loss.zero_()
|
||||
self.accum_reward.zero_()
|
||||
self.accum_acc_reward.zero_()
|
||||
self.accum_format_reward.zero_()
|
||||
self.accum_kl.zero_()
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_response_length.zero_()
|
||||
|
||||
self.accum_count = 0
|
||||
return loss_scalar
|
||||
|
||||
def state_dict(self):
|
||||
self.policy_model._force_wait_all_gather()
|
||||
model = self.policy_model.unwrap()
|
||||
state_dict = model.state_dict()
|
||||
return state_dict
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GRPOEvalConsumer(BaseConsumer):
|
||||
def __init__(
|
||||
self,
|
||||
num_producers,
|
||||
num_episodes,
|
||||
rank,
|
||||
world_size,
|
||||
master_addr,
|
||||
master_port,
|
||||
num_update_per_episode,
|
||||
num_recv_per_update,
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size=1,
|
||||
num_generations=4,
|
||||
use_wandb=True,
|
||||
log_dir="./results",
|
||||
):
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
rank,
|
||||
world_size,
|
||||
master_addr,
|
||||
master_port,
|
||||
num_update_per_episode,
|
||||
num_recv_per_update,
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.policy_model.train()
|
||||
self.accum_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_format_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_acc_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
||||
self.accum_count = torch.zeros(1, device=self.device)
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.num_generations = num_generations
|
||||
|
||||
# Initialize verifiable reward.
|
||||
response_format_tags = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
self.reward_model = VerifiableReward(
|
||||
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
|
||||
)
|
||||
|
||||
self.log_dir = log_dir
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
else:
|
||||
os.system(f"rm -rf {self.log_dir}/*")
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.policy_model, _, *_ = self.booster.boost(self.policy_model)
|
||||
|
||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||
rank = dist.get_rank()
|
||||
data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
|
||||
kwargs["input_ids"].size(0)
|
||||
reward_group = self.reward_model(
|
||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
|
||||
)
|
||||
reward = [value[0].item() for value in reward_group]
|
||||
format_reward = [value[1].item() for value in reward_group]
|
||||
acc_reward = [value[2].item() for value in reward_group]
|
||||
response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
|
||||
|
||||
response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
|
||||
with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
|
||||
for i in range(len(response)):
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"response": response[i],
|
||||
"reward": reward[i],
|
||||
"format_reward": format_reward[i],
|
||||
"acc_reward": acc_reward[i],
|
||||
"response_length": response_length[i],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
self.accum_reward += sum(reward)
|
||||
self.accum_format_reward += sum(format_reward)
|
||||
self.accum_acc_reward += sum(acc_reward)
|
||||
self.accum_response_length += sum(response_length)
|
||||
self.accum_count += len(reward)
|
||||
|
||||
# print results
|
||||
total_count = all_reduce_mean(self.accum_count, self.plugin)
|
||||
mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
|
||||
mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
|
||||
mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
|
||||
mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
|
||||
)
|
||||
return None
|
||||
|
||||
def state_dict(self):
|
||||
self.policy_model._force_wait_all_gather()
|
||||
model = self.policy_model.unwrap()
|
||||
state_dict = model.state_dict()
|
||||
return state_dict
|
@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
)
|
||||
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
|
||||
|
||||
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
):
|
||||
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
|
||||
model_config.update(self.FORCE_MODEL_CONFIG)
|
||||
path = model_config.pop("path")
|
||||
@ -61,12 +67,22 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
self.generate_config = generate_config.copy()
|
||||
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = num_generations
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
micro_batch_size = input_ids.size(0)
|
||||
input_ids = input_ids.to(get_current_device())
|
||||
attention_mask = attention_mask.to(get_current_device())
|
||||
out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
|
||||
gt_answer = None
|
||||
if "gt_answer" in kwargs:
|
||||
gt_answer = kwargs.pop("gt_answer")
|
||||
if self.num_generations > 1:
|
||||
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
|
||||
out = self.model.generate(
|
||||
input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
|
||||
)
|
||||
input_len = input_ids.shape[-1]
|
||||
new_token_ids = out.sequences[:, input_len:]
|
||||
# get log probs
|
||||
@ -76,10 +92,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
|
||||
action_log_probs = torch.cat(action_log_probs, dim=1)
|
||||
# get action mask
|
||||
response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
|
||||
action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
|
||||
if self.tokenizer.eos_token_id is not None:
|
||||
for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
|
||||
action_mask[indices[0], indices[1] + 1 :] = 0
|
||||
response_idx[:, 0] = input_len
|
||||
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
|
||||
|
||||
if attention_mask.size(0) != action_mask.size(0):
|
||||
assert action_mask.size(0) % attention_mask.size(0) == 0
|
||||
@ -91,7 +110,15 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
"attention_mask": attention_mask,
|
||||
"action_log_probs": action_log_probs,
|
||||
"action_mask": action_mask,
|
||||
"response_idx": response_idx,
|
||||
}
|
||||
|
||||
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
|
||||
|
||||
if gt_answer is not None:
|
||||
# repeat gt_answer for each prompt.
|
||||
data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
|
||||
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
||||
return data
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||
@ -99,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
|
||||
|
||||
class SGLangInferenceBackend(BaseInferenceBackend):
|
||||
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
):
|
||||
if sgl is None:
|
||||
raise ImportError("sglang is not installed")
|
||||
path = model_config.pop("path")
|
||||
@ -156,29 +189,46 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
logprobs=0,
|
||||
)
|
||||
|
||||
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
):
|
||||
if LLM is None:
|
||||
raise ImportError("vllm is not installed")
|
||||
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
|
||||
path = model_config.pop("path")
|
||||
self.llm = LLM(path, **model_config)
|
||||
self.llm = LLM(model=path, **model_config)
|
||||
generate_config = generate_config.copy()
|
||||
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
generate_config.update({"n": num_generations})
|
||||
self.generate_config = SamplingParams(**generate_config)
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = num_generations
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
micro_batch_size = input_ids.size(0)
|
||||
response_start_idx = input_ids.size(1)
|
||||
first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
|
||||
micro_batch_input_ids = input_ids.tolist()
|
||||
micro_batch_input_ids_no_padding = [
|
||||
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
|
||||
]
|
||||
outputs = self.llm.generate(
|
||||
prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
|
||||
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
|
||||
)
|
||||
out_tokens = []
|
||||
out_len = []
|
||||
log_probs = []
|
||||
response_idx = []
|
||||
for out in outputs:
|
||||
for output_i in out.outputs:
|
||||
out_len.append(len(output_i.token_ids))
|
||||
out_tokens.append(list(output_i.token_ids))
|
||||
response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
|
||||
assert len(output_i.logprobs) == len(output_i.token_ids)
|
||||
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
|
||||
log_probs.append(p)
|
||||
@ -195,6 +245,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
|
||||
out_tokens = torch.tensor(out_tokens)
|
||||
log_probs = torch.tensor(log_probs)
|
||||
response_idx = torch.tensor(response_idx)
|
||||
|
||||
if attention_mask.size(0) != action_mask.size(0):
|
||||
assert action_mask.size(0) % attention_mask.size(0) == 0
|
||||
num_returns = action_mask.size(0) // attention_mask.size(0)
|
||||
@ -209,7 +261,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
"attention_mask": attention_mask,
|
||||
"action_log_probs": log_probs,
|
||||
"action_mask": action_mask,
|
||||
"response_idx": response_idx,
|
||||
}
|
||||
|
||||
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
|
||||
|
||||
if "gt_answer" in kwargs:
|
||||
# repeat gt_answer for each prompt.
|
||||
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
|
||||
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
||||
return data
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
import copy
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import ray
|
||||
|
||||
from .consumer import SimpleConsumer
|
||||
from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
|
||||
from .producer import SimpleProducer
|
||||
|
||||
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
|
||||
|
||||
|
||||
def get_jsonl_size_fast(path: str) -> int:
|
||||
with open(path) as f:
|
||||
@ -30,6 +34,7 @@ def launch_distributed(
|
||||
inference_microbatch_size: int,
|
||||
train_batch_size: int,
|
||||
train_microbatch_size: int,
|
||||
train_minibatch_size: int,
|
||||
dataset_config: Dict[str, Any],
|
||||
dataloaders_config: Dict[str, Any],
|
||||
inference_model_config: Dict[str, Any],
|
||||
@ -38,9 +43,18 @@ def launch_distributed(
|
||||
plugin_config: Dict[str, Any],
|
||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||
inference_backend: str = "transformers",
|
||||
num_generations: int = 8,
|
||||
master_addr: str = "localhost",
|
||||
master_port: int = 29500,
|
||||
core_algo: str = "GRPO",
|
||||
project_name: Optional[str] = None,
|
||||
):
|
||||
|
||||
if core_algo not in ALGO_MAP:
|
||||
raise NotImplementedError(f"{core_algo} is not supported yet.")
|
||||
else:
|
||||
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
|
||||
|
||||
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
|
||||
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
||||
|
||||
@ -65,10 +79,17 @@ def launch_distributed(
|
||||
tokenizer_config=tokenizer_config,
|
||||
microbatch_size=inference_microbatch_size,
|
||||
backend=inference_backend,
|
||||
num_generations=num_generations,
|
||||
)
|
||||
procs.append(producer)
|
||||
generate_config_consumer = copy.deepcopy(generate_config)
|
||||
generate_config_consumer.update(
|
||||
dict(
|
||||
backend=inference_backend,
|
||||
)
|
||||
)
|
||||
for i in range(num_consumer_procs):
|
||||
consumer = SimpleConsumer.options(num_gpus=1).remote(
|
||||
consumer = core_consumer.options(num_gpus=1).remote(
|
||||
num_producers=num_producers,
|
||||
num_episodes=num_episodes,
|
||||
rank=i,
|
||||
@ -80,7 +101,15 @@ def launch_distributed(
|
||||
batch_size=train_batch_size,
|
||||
model_config=train_model_config,
|
||||
plugin_config=plugin_config,
|
||||
microbatch_size=train_microbatch_size,
|
||||
microbatch_size=train_minibatch_size,
|
||||
generate_config=generate_config_consumer,
|
||||
training_config={
|
||||
"filter_range": [0.05, 9.0],
|
||||
"lr": 1e-6,
|
||||
"train_microbatch_size": train_microbatch_size,
|
||||
},
|
||||
num_generations=num_generations,
|
||||
project_name=project_name,
|
||||
)
|
||||
procs.append(consumer)
|
||||
ray.get([p.setup.remote() for p in procs])
|
||||
|
45
applications/ColossalChat/coati/distributed/loss.py
Normal file
45
applications/ColossalChat/coati/distributed/loss.py
Normal file
@ -0,0 +1,45 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.distributed.utils import masked_mean
|
||||
|
||||
|
||||
class PolicyLoss(nn.Module):
|
||||
"""
|
||||
Policy Loss for PPO
|
||||
"""
|
||||
|
||||
def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None:
|
||||
super().__init__()
|
||||
self.clip_eps = clip_eps
|
||||
self.skip_threshold = skip_threshold
|
||||
self.beta = beta
|
||||
|
||||
def forward(
|
||||
self,
|
||||
log_probs: torch.Tensor,
|
||||
old_log_probs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
per_token_kl: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
loss_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
skip = False
|
||||
if action_mask is None:
|
||||
ratio = (log_probs - log_probs.detach()).exp()
|
||||
else:
|
||||
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
|
||||
|
||||
surr1 = ratio * advantages
|
||||
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
|
||||
|
||||
if action_mask is not None:
|
||||
loss = masked_mean(loss, action_mask)
|
||||
else:
|
||||
loss = loss.mean(dim=1)
|
||||
if loss_mask is not None:
|
||||
loss = loss * loss_mask
|
||||
loss = loss.mean()
|
||||
return loss, skip, ratio.max()
|
@ -100,7 +100,11 @@ class BaseProducer:
|
||||
if i >= num_valid_microbatches:
|
||||
break
|
||||
outputs = self.rollout(**batch)
|
||||
|
||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||
outputs["temperature"] = torch.tensor(
|
||||
[self.model.generate_config.temperature] * outputs["input_ids"].size(0)
|
||||
).to(outputs["input_ids"].device)
|
||||
outputs = pre_send(outputs)
|
||||
ray_broadcast_tensor_dict(
|
||||
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
|
||||
@ -113,10 +117,19 @@ class BaseProducer:
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
self.load_state_dict(state_dict)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
# linear annealing for 1 episode, temperature from initial to 0.7
|
||||
if episode <= 0:
|
||||
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
|
||||
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
] + ratio * 0.7
|
||||
|
||||
|
||||
@ray.remote
|
||||
@ -135,6 +148,7 @@ class SimpleProducer(BaseProducer):
|
||||
tokenizer_config=None,
|
||||
microbatch_size=1,
|
||||
backend="transformers",
|
||||
num_generations: int = 8,
|
||||
):
|
||||
super().__init__(
|
||||
producer_idx,
|
||||
@ -150,11 +164,15 @@ class SimpleProducer(BaseProducer):
|
||||
microbatch_size,
|
||||
backend,
|
||||
)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||
|
||||
@torch.no_grad()
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
return self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||
if self.producer_idx == 1:
|
||||
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
|
||||
|
||||
return rollouts
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
@ -0,0 +1,61 @@
|
||||
import torch
|
||||
|
||||
from .reward_utils import extract_solution, validate_response_structure
|
||||
|
||||
|
||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
format_score = 1.0
|
||||
acc_score = 9.0
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
reward = torch.tensor(0.0)
|
||||
format_reward = torch.tensor(0.0)
|
||||
acc_reward = torch.tensor(0.0)
|
||||
s, e = response_idx[0], response_idx[1]
|
||||
if gt_answer is None:
|
||||
return reward
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||
final_answer, processed_str = extract_solution(decoded_final_answer)
|
||||
|
||||
format_valid = validate_response_structure(processed_str, kwargs["tags"])
|
||||
|
||||
# Check format accuracy
|
||||
if format_valid:
|
||||
format_reward += format_score
|
||||
reward += format_score
|
||||
|
||||
# Check answer accuracy
|
||||
if (
|
||||
final_answer is not None
|
||||
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
||||
):
|
||||
acc_reward += acc_score
|
||||
reward += acc_score
|
||||
|
||||
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
|
||||
|
||||
|
||||
def gsm8k_reward_fn(input_ids, **kwargs):
|
||||
gt_answer = kwargs["gt_answer"]
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
s, e = kwargs["response_start"], kwargs["response_end"]
|
||||
reward = torch.tensor(0.0).to(input_ids.device)
|
||||
if gt_answer is None:
|
||||
return reward
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
final_answer, processed_str = extract_solution(decoded_final_answer)
|
||||
is_valid = True
|
||||
try:
|
||||
int(final_answer.strip())
|
||||
except Exception:
|
||||
is_valid = False
|
||||
|
||||
format_valid = validate_response_structure(processed_str, kwargs["tags"])
|
||||
if not is_valid or not format_valid:
|
||||
return reward
|
||||
else:
|
||||
reward += 1.0
|
||||
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
|
||||
reward = reward + 9.0
|
||||
return reward
|
@ -0,0 +1,76 @@
|
||||
# Copyright Unakar
|
||||
# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
|
||||
def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
|
||||
"""Performs comprehensive validation of response structure.
|
||||
|
||||
Args:
|
||||
processed_str: Processed response string from the model
|
||||
|
||||
Returns:
|
||||
Boolean indicating whether all formatting requirements are met
|
||||
"""
|
||||
validation_passed = True
|
||||
# Check required tags
|
||||
if tags is None:
|
||||
tags = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
positions = {}
|
||||
for tag_name, tag_info in tags.items():
|
||||
tag_str = tag_info["text"]
|
||||
expected_count = tag_info["num_occur"]
|
||||
count = processed_str.count(tag_str)
|
||||
positions[tag_name] = pos = processed_str.find(tag_str)
|
||||
if count != expected_count:
|
||||
validation_passed = False
|
||||
# Verify tag order
|
||||
if (
|
||||
positions["think_start"] > positions["think_end"]
|
||||
or positions["think_end"] > positions["answer_start"]
|
||||
or positions["answer_start"] > positions["answer_end"]
|
||||
):
|
||||
validation_passed = False
|
||||
if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]):
|
||||
validation_passed = False
|
||||
return validation_passed
|
||||
|
||||
|
||||
def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
|
||||
"""Extracts the final answer from the model's response string.
|
||||
|
||||
Args:
|
||||
solution_str: Raw response string from the language model
|
||||
|
||||
Returns:
|
||||
Tuple containing (extracted_answer, processed_string)
|
||||
"""
|
||||
|
||||
# Extract final answer using XML-style tags
|
||||
answer_pattern = r"<answer>(.*?)</answer>"
|
||||
matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
|
||||
|
||||
if not matches:
|
||||
return None, solution_str
|
||||
|
||||
final_answer = matches[-1].group(1).strip()
|
||||
return final_answer, solution_str
|
@ -0,0 +1,43 @@
|
||||
"""
|
||||
Function-based reward verification module.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class VerifiableReward:
|
||||
def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]):
|
||||
self.reward_fns = reward_fns
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
gt_answer: List[torch.Tensor] = None,
|
||||
response_idx: List[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Get batch size
|
||||
bs = input_ids.size(0)
|
||||
# Initialize reward
|
||||
rewards = torch.zeros((bs, 3), device=input_ids.device)
|
||||
|
||||
# Loop through reward functions
|
||||
for reward_fn in self.reward_fns:
|
||||
# Apply the reward function to the entire batch at once
|
||||
reward_batch = torch.stack(
|
||||
[
|
||||
reward_fn(
|
||||
input_ids[i],
|
||||
gt_answer=gt_answer[i],
|
||||
response_idx=response_idx[i],
|
||||
**self.kwargs,
|
||||
)
|
||||
for i in range(bs)
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
rewards += reward_batch
|
||||
return rewards
|
@ -2,6 +2,8 @@ from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.shardformer.layer.loss import dist_log_prob
|
||||
|
||||
|
||||
def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
|
||||
batches = []
|
||||
@ -64,3 +66,50 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
|
||||
log_probs = torch.log_softmax(logits, dim=-1)
|
||||
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
return per_label_logps.squeeze(-1)
|
||||
|
||||
|
||||
def calc_action_log_probs(
|
||||
logits: torch.Tensor,
|
||||
sequences: torch.LongTensor,
|
||||
num_actions: int,
|
||||
shard_config,
|
||||
vocab_size: int = None,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate action log probs.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): Output tensor of Actor.forward.logits.
|
||||
sequences (torch.LongTensor): Input sequences.
|
||||
num_actions (int): Number of actions.
|
||||
shard_config
|
||||
vocab_size
|
||||
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Action log probs.
|
||||
"""
|
||||
# labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||
# logits: torch.Tensor, # [B, S, Vocab_size]
|
||||
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
|
||||
log_probs = log_probs.squeeze(-1)
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
|
||||
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
Compute the masked mean of a tensor along a specified dimension.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The input tensor.
|
||||
mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
|
||||
dim (int, optional): The dimension along which to compute the mean. Default is 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The masked mean tensor.
|
||||
|
||||
"""
|
||||
tensor = tensor * mask
|
||||
tensor = tensor.sum(dim=dim)
|
||||
mask_sum = mask.sum(dim=dim)
|
||||
mean = tensor / (mask_sum + 1e-8)
|
||||
return mean
|
||||
|
@ -10,54 +10,83 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
||||
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
||||
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
|
||||
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
|
||||
parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
|
||||
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers")
|
||||
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
|
||||
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||
parser.add_argument(
|
||||
"-ibs",
|
||||
"--inference-batch-size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-imbs",
|
||||
"--inference-microbatch-size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tbs",
|
||||
"--train-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tMbs",
|
||||
"--train-minibatch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tmbs",
|
||||
"--train-microbatch-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
|
||||
)
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
|
||||
assert (
|
||||
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
|
||||
and args.train_microbatch_size > 0
|
||||
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
||||
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(path=args.model)
|
||||
generate_config = dict(
|
||||
top_k=50,
|
||||
top_p=0.8,
|
||||
)
|
||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
|
||||
|
||||
if args.backend == "transformers":
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
attn_implementation="flash_attention_2",
|
||||
use_flash_attention_2=True,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
)
|
||||
train_model_config.update(
|
||||
dict(
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
use_cache=False,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_length=512,
|
||||
max_length=1024 + 512,
|
||||
do_sample=True,
|
||||
max_new_tokens=None,
|
||||
early_stopping=False,
|
||||
stop_strings=["</answer>"],
|
||||
)
|
||||
)
|
||||
elif args.backend == "vllm":
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
gpu_memory_utilization=0.6,
|
||||
)
|
||||
)
|
||||
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_tokens=256,
|
||||
max_tokens=2048,
|
||||
ignore_eos=True,
|
||||
include_stop_str_in_output=True,
|
||||
stop=["</answer>"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -77,18 +106,29 @@ if __name__ == "__main__":
|
||||
num_producers=args.num_inferencer,
|
||||
num_proc_per_producer=1,
|
||||
num_consumer_procs=args.num_trainers,
|
||||
num_episodes=1,
|
||||
num_episodes=10,
|
||||
inference_batch_size=args.inference_batch_size,
|
||||
inference_microbatch_size=args.inference_microbatch_size,
|
||||
train_batch_size=args.train_batch_size,
|
||||
train_minibatch_size=args.train_minibatch_size,
|
||||
train_microbatch_size=args.train_microbatch_size,
|
||||
dataset_config={"path": args.dataset, "max_length": 256},
|
||||
dataset_config={"path": args.dataset, "max_length": 300},
|
||||
dataloaders_config={},
|
||||
inference_model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
num_generations=args.num_generations,
|
||||
train_model_config=train_model_config,
|
||||
plugin_config={},
|
||||
# plugin_config={}, # for zero
|
||||
plugin_config={
|
||||
"pp_size": 2,
|
||||
"tp_size": 2,
|
||||
"microbatch_size": args.train_microbatch_size // 2,
|
||||
"zero_stage": 0,
|
||||
"max_norm": 1.0,
|
||||
}, # for pp
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=29504,
|
||||
master_port=29506,
|
||||
core_algo=args.algo,
|
||||
project_name=args.project,
|
||||
)
|
||||
|
@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
)
|
||||
|
||||
# run with gradients accumulation
|
||||
if model.require_grad_sync == False or (
|
||||
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
|
||||
if (
|
||||
not torch.is_grad_enabled()
|
||||
or model.require_grad_sync == False
|
||||
or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
|
||||
):
|
||||
return outputs
|
||||
|
||||
|
@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
|
||||
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
||||
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
|
||||
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||
from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
from .qkv_fused_linear import (
|
||||
@ -28,6 +28,8 @@ __all__ = [
|
||||
"DropoutForReplicatedInput",
|
||||
"cross_entropy_1d",
|
||||
"dist_cross_entropy",
|
||||
"dist_log_prob_1d",
|
||||
"dist_log_prob",
|
||||
"BaseLayerNorm",
|
||||
"LayerNorm",
|
||||
"RMSNorm",
|
||||
|
@ -3,13 +3,21 @@ import torch.distributed as dist
|
||||
from torch.autograd import Function
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.functional import log_softmax
|
||||
|
||||
from colossalai.shardformer.layer._operation import reduce_forward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from .utils import is_share_sp_tp
|
||||
|
||||
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
|
||||
__all__ = [
|
||||
"DistCrossEntropy",
|
||||
"cross_entropy_1d",
|
||||
"dist_cross_entropy",
|
||||
"DistLogProb",
|
||||
"dist_log_prob_1d",
|
||||
"dist_log_prob",
|
||||
]
|
||||
|
||||
_IGNORE_IDX = -100
|
||||
|
||||
@ -137,6 +145,98 @@ class DistCrossEntropy(Function):
|
||||
return grad_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
class DistLogProb(Function):
|
||||
r"""
|
||||
Overwrite the forward and backward function to calculate the log prob before gather
|
||||
|
||||
Args:
|
||||
Function (:class:`torch.autograd.Function`): default
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
vocab_logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
process_group: ProcessGroup,
|
||||
vocab_size: int,
|
||||
dtype=torch.float32,
|
||||
):
|
||||
|
||||
##################
|
||||
# Step1:Find the global maximum value of logits
|
||||
##################
|
||||
logits_max = torch.max(vocab_logits, dim=-1)[0]
|
||||
handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
|
||||
|
||||
##################
|
||||
# Step2:Find the local mask. local mask will be use to select log_probs value in Step 4.
|
||||
# For accleration, we overlap Step 2 and Step 3
|
||||
##################
|
||||
rank = dist.get_rank(group=process_group)
|
||||
world_size = dist.get_world_size(group=process_group)
|
||||
if vocab_size is None:
|
||||
partition_vocab_size = vocab_logits.size()[-1]
|
||||
global_vocab_size = partition_vocab_size * world_size
|
||||
else:
|
||||
global_vocab_size = vocab_size
|
||||
partition_vocab_size = global_vocab_size // world_size
|
||||
# down and up threshold for local logits
|
||||
delta = (global_vocab_size + world_size - 1) // world_size
|
||||
down_threshold = rank * delta
|
||||
up_threshold = down_threshold + delta
|
||||
if up_threshold > global_vocab_size:
|
||||
up_threshold = global_vocab_size
|
||||
# mask
|
||||
mask = (target < down_threshold) | (target >= up_threshold)
|
||||
masked_target = target.clone() - down_threshold
|
||||
masked_target[mask] = 0
|
||||
masked_target_1d = masked_target.view(-1).contiguous()
|
||||
handle.wait()
|
||||
|
||||
##################
|
||||
# Step3:Calculate global summation exp logits
|
||||
##################
|
||||
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
||||
exp_logits = torch.exp(vocab_logits)
|
||||
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits
|
||||
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
|
||||
|
||||
##################
|
||||
# Step4:Calculate local prob. We first cal log_softmax, then select log probs via local mask
|
||||
##################
|
||||
log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax
|
||||
log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
|
||||
log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero
|
||||
dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)
|
||||
|
||||
ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
|
||||
ctx.dtype = dtype
|
||||
return log_probs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
|
||||
##################
|
||||
# Step1:Find the global sofmax value
|
||||
##################
|
||||
softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)
|
||||
|
||||
##################
|
||||
# Step2:Update softmax value based on local target index
|
||||
##################
|
||||
partion_vocab_size = softmax_logits.shape[-1]
|
||||
softmax_logits_2d = softmax_logits.view(-1, partion_vocab_size)
|
||||
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
|
||||
softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update
|
||||
|
||||
##################
|
||||
# Step3:Calculate grad_output, which is the gradient of the loss function with respect to the output of logsoftmax
|
||||
##################
|
||||
grad_logits = -softmax_logits.mul_(grad_output)
|
||||
return grad_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
def cross_entropy_1d(
|
||||
vocab_logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
@ -149,6 +249,16 @@ def cross_entropy_1d(
|
||||
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
|
||||
|
||||
|
||||
def dist_log_prob_1d(
|
||||
vocab_logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
process_group: ProcessGroup = None,
|
||||
vocab_size: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
) -> torch.Tensor:
|
||||
return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)
|
||||
|
||||
|
||||
def dist_cross_entropy(
|
||||
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||
logits: torch.Tensor, # [B, S, Vocab_size]
|
||||
@ -243,3 +353,41 @@ def dist_cross_entropy(
|
||||
loss, num_nonzero = loss[0], loss[1].detach()
|
||||
loss = (loss / num_nonzero).squeeze()
|
||||
return loss
|
||||
|
||||
|
||||
def dist_log_prob(
|
||||
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||
logits: torch.Tensor, # [B, S, Vocab_size]
|
||||
shard_config: ShardConfig,
|
||||
vocab_size: int,
|
||||
dtype: torch.dtype,
|
||||
seq_dim: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper to compute log prob for most shardformer models supporting PP, TP.
|
||||
"""
|
||||
# Split labels if not gather output
|
||||
parallel_output = shard_config.parallel_output
|
||||
is_tp = shard_config.enable_tensor_parallelism
|
||||
|
||||
# TODO:support sp
|
||||
labels = labels[..., 1:]
|
||||
logits = logits[..., :-1, :]
|
||||
labels = labels.contiguous()
|
||||
logits = logits.contiguous()
|
||||
assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
|
||||
|
||||
# Flatten the tokens
|
||||
if is_tp and parallel_output:
|
||||
log_prob = dist_log_prob_1d(
|
||||
logits,
|
||||
labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=vocab_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
log_prob = log_softmax(logits, dim=-1)
|
||||
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
|
||||
return log_prob
|
||||
|
@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@ -832,7 +833,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
@ -13,6 +13,7 @@ from colossalai.shardformer.layer import (
|
||||
PaddingEmbedding,
|
||||
RMSNorm,
|
||||
VocabParallelEmbedding1D,
|
||||
VocabParallelLMHead1D,
|
||||
)
|
||||
|
||||
from ..modeling.qwen2 import (
|
||||
@ -429,8 +430,12 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=not self.shard_config.parallel_output,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||
@ -446,7 +451,16 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
||||
suffix="lm_head",
|
||||
target_module=LinearWithGradAccum,
|
||||
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
|
||||
)
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": not self.shard_config.parallel_output,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
"fp8_communication": self.shard_config.fp8_communication,
|
||||
},
|
||||
),
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||
)
|
||||
|
52
tests/test_shardformer/test_layer/test_dist_log_prob.py
Normal file
52
tests/test_shardformer/test_layer/test_dist_log_prob.py
Normal file
@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
import torch
|
||||
from coati.distributed.utils import log_probs_from_logits
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer.layer import dist_log_prob_1d
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
|
||||
)
|
||||
|
||||
|
||||
def check_dist_log_prob(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
|
||||
|
||||
# prepare data
|
||||
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
|
||||
labels = torch.randint(8, (2, 4)).cuda()
|
||||
|
||||
logprob = log_probs_from_logits(pred, labels)
|
||||
|
||||
pred.retain_grad()
|
||||
logprob.mean().backward()
|
||||
|
||||
dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
|
||||
dist_pred.requires_grad = True
|
||||
dist_logprob = dist_log_prob_1d(dist_pred, labels)
|
||||
|
||||
dist_pred.retain_grad()
|
||||
dist_logprob.squeeze(-1).mean().backward()
|
||||
|
||||
assert torch.allclose(
|
||||
logprob, dist_logprob.squeeze(-1), atol=1e-5
|
||||
), f"dist cross entropy logprob is not equal to orgin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}"
|
||||
|
||||
pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
|
||||
assert torch.allclose(
|
||||
pred_grad_partial, dist_pred.grad
|
||||
), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}"
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_dist_log_prob():
|
||||
spawn(check_dist_log_prob, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dist_log_prob()
|
Loading…
Reference in New Issue
Block a user