diff --git a/.gitignore b/.gitignore
index 8bc74b4c8..533450a7c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,5 @@ coverage.xml
# log, test files - ColossalChat
applications/ColossalChat/logs
applications/ColossalChat/tests/logs
+applications/ColossalChat/wandb
+applications/ColossalChat/model
diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index fc3a4930c..4518fd71f 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -356,10 +356,24 @@ def apply_chat_template_and_mask(
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
+
+ system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+
+ system_element = {
+ "role": "system",
+ "content": system_prompt,
+ }
+
+ # Format for RL.
+ gt_answer = None
+ if "messages" in chat and "gt_answer" in chat:
+ gt_answer = chat["gt_answer"]
+ chat = [chat["messages"]]
+
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
- msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
+ msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]
@@ -372,14 +386,10 @@ def apply_chat_template_and_mask(
if max_length is not None:
if padding and len(tokens) < max_length:
to_pad = max_length - len(tokens)
- if tokenizer.padding_side == "right":
- tokens.extend([tokenizer.pad_token_id] * to_pad)
- assistant_mask.extend([False] * to_pad)
- attention_mask.extend([0] * to_pad)
- else:
- tokens = [tokenizer.pad_token_id] * to_pad + tokens
- assistant_mask = [False] * to_pad + assistant_mask
- attention_mask = [0] * to_pad + attention_mask
+ # Left padding for generation.
+ tokens = [tokenizer.pad_token_id] * to_pad + tokens
+ assistant_mask = [False] * to_pad + assistant_mask
+ attention_mask = [0] * to_pad + attention_mask
if truncation and len(tokens) > max_length:
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
@@ -389,6 +399,13 @@ def apply_chat_template_and_mask(
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
+ if gt_answer is not None:
+ gt_answer = tokenizer.encode(
+ gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
+ )
+ gt_answer = gt_answer.squeeze(1)
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
+
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 84a69979f..79beb2a2d 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -1,3 +1,4 @@
+import os
from contextlib import nullcontext
from typing import Any, Dict, Optional
@@ -33,6 +34,8 @@ class BaseConsumer:
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
+ save_interval: int = 100,
+ save_dir: str = "./model",
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@@ -44,14 +47,16 @@ class BaseConsumer:
self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size
self.microbatch_size = microbatch_size
+ self.save_interval = save_interval
+ self.save_dir = save_dir
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size
self.model_config = model_config
self.plugin_config = plugin_config
- assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
self.device = get_current_device()
+ self.lr_scheduler = None
def setup(self) -> None:
for i in range(self.num_producers):
@@ -60,18 +65,15 @@ class BaseConsumer:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
- plugin_config = dict(
- tp_size=1,
- pp_size=1,
- precision="bf16",
- zero_stage=1,
- )
+ plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
+ self.tp_rank = dist.get_rank(self.plugin.tp_group)
+
self.dp_size = dist.get_world_size(self.plugin.dp_group)
self.buffer = []
@@ -94,7 +96,6 @@ class BaseConsumer:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers
-
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
@@ -116,13 +117,26 @@ class BaseConsumer:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
+ if self.lr_scheduler is not None:
+ self.lr_scheduler.step()
+ if (step + 1) % self.save_interval == 0:
+ if self.rank == 0:
+ print(f"Start saving policy model at step {step + 1}.")
+ save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
+ self.booster.save_model(self.policy_model, save_path, shard=True)
+ if self.rank == 0:
+ print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
+
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+ torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
+ del state_dict
+ torch.cuda.empty_cache()
@ray.remote
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
new file mode 100644
index 000000000..e23254d1b
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -0,0 +1,529 @@
+import json
+import os
+from contextlib import nullcontext
+from typing import Optional
+
+import ray
+import torch
+import torch.distributed as dist
+import wandb
+from coati.distributed.consumer import BaseConsumer
+from coati.distributed.loss import PolicyLoss
+from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs
+from coati.trainer.utils import all_reduce_mean
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+
+
+@ray.remote
+class GRPOConsumer(BaseConsumer):
+ def __init__(
+ self,
+ num_producers,
+ num_episodes,
+ rank,
+ world_size,
+ master_addr,
+ master_port,
+ num_update_per_episode,
+ num_recv_per_update,
+ batch_size,
+ model_config,
+ plugin_config,
+ microbatch_size=1,
+ num_generations=8,
+ use_wandb=True,
+ generate_config=None,
+ training_config={},
+ project_name=None,
+ ):
+ super().__init__(
+ num_producers,
+ num_episodes,
+ rank,
+ world_size,
+ master_addr,
+ master_port,
+ num_update_per_episode,
+ num_recv_per_update,
+ batch_size,
+ model_config,
+ plugin_config,
+ microbatch_size,
+ )
+ path = model_config.pop("path")
+ self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+ self.policy_model.train()
+ self.policy_model.gradient_checkpointing_enable()
+ self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
+ self.accum_loss = torch.zeros(1, device=self.device)
+ self.accum_reward = torch.zeros(1, device=self.device)
+ self.accum_kl = torch.zeros(1, device=self.device)
+ self.accum_format_reward = torch.zeros(1, device=self.device)
+ self.accum_acc_reward = torch.zeros(1, device=self.device)
+ self.accum_advantages = torch.zeros(1, device=self.device)
+ self.accum_response_length = torch.zeros(1, device=self.device)
+ self.accum_count = 0
+ self.generate_config = generate_config
+ self.training_config = training_config
+ self.project_name = project_name
+
+ # Reference model is initialized from policy model.
+ self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+ self.reference_model.eval()
+
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
+ self.pad_token_id = self.tokenizer.pad_token_id
+ self.num_generations = num_generations
+ self.filter_range = training_config.get("filter_range", None)
+ if self.filter_range is not None:
+ assert len(self.filter_range) == 2, "Filter range should have 2 values."
+
+ # Initialize verifiable reward.
+ response_format_tags = {
+ "think_start": {"text": "<think>", "num_occur": 1},
+ "think_end": {"text": "</think>", "num_occur": 1},
+ "answer_start": {"text": "<answer>", "num_occur": 1},
+ "answer_end": {"text": "</answer>", "num_occur": 1},
+ }
+ self.reward_model = VerifiableReward(
+ reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+ )
+
+ self.policy_loss_fn = PolicyLoss()
+ self.global_step = 0
+ self.use_wandb = use_wandb
+
+ self.lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=self.optimizer,
+ total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+ warmup_steps=0,
+ eta_min=0.1 * training_config.get("lr", 1e-6),
+ )
+
+ def setup(self):
+ super().setup()
+ if self.use_wandb and (
+ (not self.plugin.pp_size > 1 and self.rank == 0)
+ or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
+ ):
+ # Initialize wandb.
+ name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
+ self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+
+ self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+ self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+ )
+ self.reference_model, *_ = self.booster.boost(self.reference_model)
+ self.plugin.logger.set_level("ERROR")
+
+ def step(self, step_idx: int, **kwargs) -> Optional[float]:
+ """
+ Step data from policy model:
+ [{
+ "input_ids": torch.Tensor,
+ "attention_mask": torch.Tensor,
+ "action_mask": torch.Tensor,
+ "action_log_probs": torch.Tensor,
+ },
+ ...]
+ Format:
+ [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
+ """
+
+ # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
+ data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+ action_mask = data["action_mask"]
+ num_action = action_mask.shape[1]
+ old_action_log_probs = data["action_log_probs"]
+ response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+ forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
+
+ need_update = (step_idx + 1) % self.num_microbatches == 0
+ # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
+ ctx = (
+ nullcontext()
+ if need_update or self.booster.plugin.zero_stage == 2
+ else self.booster.no_sync(self.policy_model, self.optimizer)
+ )
+ with ctx:
+ reward_group = self.reward_model(
+ data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+ )
+
+ reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+ format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+ acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+
+ # [batch_size, num_generations]
+
+ group_reward = reward.view(-1, self.num_generations)
+ reward_mean = group_reward.mean(dim=1)
+ # [batch_size x num_generations]
+ reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
+ reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+ # [batch_size x num_generations]
+ advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+ # filter out prompts whose group reward is too high (all samples get full score) or too low (all samples get 0 score)
+ loss_mask = (
+ None
+ if self.filter_range is None
+ else torch.logical_and(
+ reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+ ).repeat_interleave(self.num_generations, dim=0)
+ )
+ mean_kl, mean_loss = [], []
+
+ for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+ input_ids_forward_micro_batch = data["input_ids"][
+ forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+ ]
+ attention_mask_forward_micro_batch = data["attention_mask"][
+ forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+ ]
+ action_mask_forward_micro_batch = action_mask[
+ forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+ ]
+ loss_mask_forward_micro_batch = (
+ loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+ if loss_mask is not None
+ else None
+ )
+ advantages_forward_micro_batch = advantages[
+ forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+ ]
+
+ if self.plugin.pp_size > 1:
+ # Support training with PP.
+
+ with torch.no_grad():
+ reference_model_outputs = self.booster.execute_pipeline(
+ iter(
+ [
+ {
+ "input_ids": input_ids_forward_micro_batch,
+ "attention_mask": attention_mask_forward_micro_batch,
+ }
+ ]
+ ),
+ self.reference_model,
+ criterion=lambda outputs, inputs: torch.tensor(
+ [0.0], device=action_mask.device
+ ), # dummy criterion
+ optimizer=None,
+ return_loss=False,
+ return_outputs=True,
+ )
+
+ if self.booster.plugin.stage_manager.is_last_stage():
+ reference_model_logits = reference_model_outputs["outputs"]["logits"]
+ reference_action_log_probs = calc_action_log_probs(
+ reference_model_logits / self.generate_config["temperature"],
+ input_ids_forward_micro_batch,
+ num_action,
+ self.plugin.shard_config,
+ )
+ else:
+ # Dummy reference logprobs for data iterator.
+ reference_action_log_probs = None
+
+ data_policy_forward = {
+ "input_ids": input_ids_forward_micro_batch,
+ "attention_mask": attention_mask_forward_micro_batch,
+ "action_mask": action_mask_forward_micro_batch,
+ "reference_action_log_probs": reference_action_log_probs,
+ "advantages": advantages_forward_micro_batch,
+ "loss_mask": loss_mask_forward_micro_batch,
+ "source": self.rank,
+ }
+
+ kl = []
+
+ def _criterion(outputs, inputs):
+ action_logits = outputs.logits
+ action_log_probs = calc_action_log_probs(
+ action_logits / self.generate_config["temperature"],
+ inputs["input_ids"],
+ num_action,
+ self.plugin.shard_config,
+ )
+ per_token_kl = (
+ torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
+ - (inputs["reference_action_log_probs"] - action_log_probs)
+ - 1
+ )
+ approx_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
+ inputs["action_mask"], dim=-1
+ )
+ kl.append(approx_kl.mean())
+ loss, skip_update, _ = self.policy_loss_fn(
+ action_log_probs,
+ action_log_probs,
+ inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+ per_token_kl,
+ inputs["action_mask"],
+ loss_mask=inputs["loss_mask"],
+ )
+ return loss
+
+ policy_model_outputs = self.booster.execute_pipeline(
+ iter([data_policy_forward]),
+ self.policy_model,
+ criterion=_criterion,
+ optimizer=self.optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ loss = policy_model_outputs["loss"]
+
+ if self.booster.plugin.stage_manager.is_last_stage():
+ if len(kl) > 0:
+ kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
+ mean_kl.append(kl)
+ mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+ else:
+
+ policy_model_logits = self.policy_model(
+ input_ids=input_ids_forward_micro_batch,
+ attention_mask=attention_mask_forward_micro_batch,
+ ).logits
+ action_log_probs = calc_action_log_probs(
+ policy_model_logits / self.generate_config["temperature"],
+ input_ids_forward_micro_batch,
+ num_action,
+ self.plugin.shard_config,
+ )
+
+ with torch.no_grad():
+ reference_model_logits = self.reference_model(
+ input_ids=input_ids_forward_micro_batch,
+ attention_mask=attention_mask_forward_micro_batch,
+ ).logits
+ reference_action_log_probs = calc_action_log_probs(
+ reference_model_logits / self.generate_config["temperature"],
+ input_ids_forward_micro_batch,
+ num_action,
+ self.plugin.shard_config,
+ )
+ per_token_kl = (
+ torch.exp(reference_action_log_probs - action_log_probs)
+ - (reference_action_log_probs - action_log_probs)
+ - 1
+ )
+ kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+ action_mask_forward_micro_batch, dim=-1
+ )
+
+ loss, skip_update, _ = self.policy_loss_fn(
+ action_log_probs,
+ old_action_log_probs,
+ advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+ per_token_kl,
+ action_mask_forward_micro_batch,
+ loss_mask=loss_mask_forward_micro_batch,
+ )
+
+ if not skip_update:
+ self.booster.backward(loss, self.optimizer)
+ loss = all_reduce_mean(loss, self.plugin)
+ kl = all_reduce_mean(kl.mean(), self.plugin)
+ # Calculate accumulate value.
+ mean_kl.append(kl.data)
+ mean_loss.append(loss.data)
+ if not self.plugin.pp_size > 1 or (
+ self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+ ):
+ reward = all_reduce_mean(reward.mean(), self.plugin)
+ format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+ acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+ advantages = all_reduce_mean(advantages.mean(), self.plugin)
+ response_length = all_reduce_mean(response_length.mean(), self.plugin)
+ self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+ self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
+ self.accum_reward.add_(reward.data)
+ self.accum_format_reward.add_(format_reward.data)
+ self.accum_acc_reward.add_(acc_reward.data)
+ self.accum_advantages.add_(advantages.data)
+ self.accum_response_length.add_(response_length.data)
+ self.accum_count += 1
+ if need_update:
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+ if not self.plugin.pp_size > 1 or (
+ self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+ ):
+ loss_scalar = self.accum_loss.item()
+ if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+ self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+ ):
+ print(
+ "Loss:",
+ self.accum_loss.item() / self.accum_count,
+ "\nReward:",
+ self.accum_reward.item() / self.accum_count,
+ "\nFormat Reward:",
+ self.accum_format_reward.item() / self.accum_count,
+ "\nAcc Reward:",
+ self.accum_acc_reward.item() / self.accum_count,
+ "\nKL:",
+ self.accum_kl.item() / self.accum_count,
+ "\nAdvantages:",
+ self.accum_advantages.item() / self.accum_count,
+ "\nResponse Length:",
+ self.accum_response_length.item() / self.accum_count,
+ )
+ self.wandb_run.log(
+ {
+ "metrics/reward": self.accum_reward.item() / self.accum_count,
+ "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+ "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+ "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+ "train/loss": self.accum_loss.item() / self.accum_count,
+ "train/kl": self.accum_kl.item() / self.accum_count,
+ "train/advantages": self.accum_advantages.item() / self.accum_count,
+ "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+ "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
+ }
+ )
+ self.accum_loss.zero_()
+ self.accum_reward.zero_()
+ self.accum_acc_reward.zero_()
+ self.accum_format_reward.zero_()
+ self.accum_kl.zero_()
+ self.accum_advantages.zero_()
+ self.accum_response_length.zero_()
+
+ self.accum_count = 0
+ return loss_scalar
+
+ def state_dict(self):
+ self.policy_model._force_wait_all_gather()
+ model = self.policy_model.unwrap()
+ state_dict = model.state_dict()
+ return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+ def __init__(
+ self,
+ num_producers,
+ num_episodes,
+ rank,
+ world_size,
+ master_addr,
+ master_port,
+ num_update_per_episode,
+ num_recv_per_update,
+ batch_size,
+ model_config,
+ plugin_config,
+ microbatch_size=1,
+ num_generations=4,
+ use_wandb=True,
+ log_dir="./results",
+ ):
+ super().__init__(
+ num_producers,
+ num_episodes,
+ rank,
+ world_size,
+ master_addr,
+ master_port,
+ num_update_per_episode,
+ num_recv_per_update,
+ batch_size,
+ model_config,
+ plugin_config,
+ microbatch_size,
+ )
+ path = model_config.pop("path")
+ self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+ self.policy_model.train()
+ self.accum_reward = torch.zeros(1, device=self.device)
+ self.accum_format_reward = torch.zeros(1, device=self.device)
+ self.accum_acc_reward = torch.zeros(1, device=self.device)
+ self.accum_response_length = torch.zeros(1, device=self.device)
+ self.accum_count = torch.zeros(1, device=self.device)
+
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
+ self.pad_token_id = self.tokenizer.pad_token_id
+ self.num_generations = num_generations
+
+ # Initialize verifiable reward.
+ response_format_tags = {
+ "think_start": {"text": "<think>", "num_occur": 1},
+ "think_end": {"text": "</think>", "num_occur": 1},
+ "answer_start": {"text": "<answer>", "num_occur": 1},
+ "answer_end": {"text": "</answer>", "num_occur": 1},
+ }
+ self.reward_model = VerifiableReward(
+ reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+ )
+
+ self.log_dir = log_dir
+ if not os.path.exists(self.log_dir):
+ os.makedirs(self.log_dir)
+ else:
+ os.system(f"rm -rf {self.log_dir}/*")
+
+ def setup(self):
+ super().setup()
+ self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+ def step(self, step_idx: int, **kwargs) -> Optional[float]:
+ rank = dist.get_rank()
+ data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+ kwargs["input_ids"].size(0)
+ reward_group = self.reward_model(
+ data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+ )
+ reward = [value[0].item() for value in reward_group]
+ format_reward = [value[1].item() for value in reward_group]
+ acc_reward = [value[2].item() for value in reward_group]
+ response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+ response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+ with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+ for i in range(len(response)):
+ f.write(
+ json.dumps(
+ {
+ "response": response[i],
+ "reward": reward[i],
+ "format_reward": format_reward[i],
+ "acc_reward": acc_reward[i],
+ "response_length": response_length[i],
+ },
+ ensure_ascii=False,
+ )
+ + "\n"
+ )
+
+ self.accum_reward += sum(reward)
+ self.accum_format_reward += sum(format_reward)
+ self.accum_acc_reward += sum(acc_reward)
+ self.accum_response_length += sum(response_length)
+ self.accum_count += len(reward)
+
+ # print results
+ total_count = all_reduce_mean(self.accum_count, self.plugin)
+ mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+ mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+ mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+ mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+ if rank == 0:
+ print(
+ f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+ )
+ return None
+
+ def state_dict(self):
+ self.policy_model._force_wait_all_gather()
+ model = self.policy_model.unwrap()
+ state_dict = model.state_dict()
+ return state_dict
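In equation form, the consumer above implements the standard GRPO quantities (a restatement of the code, with G = `num_generations`, the same 1e-4 std offset, and the k3 KL estimator used in `step`):

```latex
\[
A_i = \frac{r_i - \operatorname{mean}(r_{1:G})}{\operatorname{std}(r_{1:G}) + 10^{-4}},
\qquad
\mathrm{kl}_t = e^{\,\log\pi_{\mathrm{ref}}(a_t) - \log\pi_\theta(a_t)}
  - \bigl(\log\pi_{\mathrm{ref}}(a_t) - \log\pi_\theta(a_t)\bigr) - 1 .
\]
```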
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 95b7d1e80..17c71c8a8 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
- def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+ def __init__(
+ self,
+ model_config: Dict[str, Any],
+ generate_config: Dict[str, Any],
+ tokenizer: PreTrainedTokenizer,
+ num_generations: int = 8,
+ ):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
path = model_config.pop("path")
@@ -61,12 +67,22 @@ class TransformersInferenceBackend(BaseInferenceBackend):
self.generate_config = generate_config.copy()
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer
+ self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+ micro_batch_size = input_ids.size(0)
input_ids = input_ids.to(get_current_device())
attention_mask = attention_mask.to(get_current_device())
- out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
+ gt_answer = None
+ if "gt_answer" in kwargs:
+ gt_answer = kwargs.pop("gt_answer")
+ if self.num_generations > 1:
+ input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
+ attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
+ out = self.model.generate(
+ input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
+ )
input_len = input_ids.shape[-1]
new_token_ids = out.sequences[:, input_len:]
# get log probs
@@ -76,10 +92,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
action_log_probs = torch.cat(action_log_probs, dim=1)
# get action mask
+ response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
if self.tokenizer.eos_token_id is not None:
for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
action_mask[indices[0], indices[1] + 1 :] = 0
+ response_idx[:, 0] = input_len
+ response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
if attention_mask.size(0) != action_mask.size(0):
assert action_mask.size(0) % attention_mask.size(0) == 0
@@ -91,7 +110,15 @@ class TransformersInferenceBackend(BaseInferenceBackend):
"attention_mask": attention_mask,
"action_log_probs": action_log_probs,
"action_mask": action_mask,
+ "response_idx": response_idx,
}
+
+ data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
+ if gt_answer is not None:
+ # repeat gt_answer for each prompt.
+ data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+ data = {k: v.to(get_current_device()) for k, v in data.items()}
return data
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
@@ -99,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
class SGLangInferenceBackend(BaseInferenceBackend):
- def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+ def __init__(
+ self,
+ model_config: Dict[str, Any],
+ generate_config: Dict[str, Any],
+ tokenizer: PreTrainedTokenizer,
+ num_generations: int = 8,
+ ):
if sgl is None:
raise ImportError("sglang is not installed")
path = model_config.pop("path")
@@ -156,29 +189,46 @@ class VLLMInferenceBackend(BaseInferenceBackend):
logprobs=0,
)
- def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+ def __init__(
+ self,
+ model_config: Dict[str, Any],
+ generate_config: Dict[str, Any],
+ tokenizer: PreTrainedTokenizer,
+ num_generations: int = 8,
+ ):
if LLM is None:
raise ImportError("vllm is not installed")
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
path = model_config.pop("path")
- self.llm = LLM(path, **model_config)
+ self.llm = LLM(model=path, **model_config)
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
+ generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer
+ self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+ micro_batch_size = input_ids.size(0)
+ response_start_idx = input_ids.size(1)
+ first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
+ micro_batch_input_ids = input_ids.tolist()
+ micro_batch_input_ids_no_padding = [
+ micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+ ]
outputs = self.llm.generate(
- prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+ prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
)
out_tokens = []
out_len = []
log_probs = []
+ response_idx = []
for out in outputs:
for output_i in out.outputs:
out_len.append(len(output_i.token_ids))
out_tokens.append(list(output_i.token_ids))
+ response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
assert len(output_i.logprobs) == len(output_i.token_ids)
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
log_probs.append(p)
@@ -195,6 +245,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
out_tokens = torch.tensor(out_tokens)
log_probs = torch.tensor(log_probs)
+ response_idx = torch.tensor(response_idx)
+
if attention_mask.size(0) != action_mask.size(0):
assert action_mask.size(0) % attention_mask.size(0) == 0
num_returns = action_mask.size(0) // attention_mask.size(0)
@@ -209,7 +261,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
"attention_mask": attention_mask,
"action_log_probs": log_probs,
"action_mask": action_mask,
+ "response_idx": response_idx,
}
+
+ data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
+ if "gt_answer" in kwargs:
+ # repeat gt_answer for each prompt.
+ data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 438c46300..699d90a8c 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -1,10 +1,14 @@
+import copy
from typing import Any, Dict, Optional
import ray
from .consumer import SimpleConsumer
+from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
from .producer import SimpleProducer
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
+
def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
@@ -30,6 +34,7 @@ def launch_distributed(
inference_microbatch_size: int,
train_batch_size: int,
train_microbatch_size: int,
+ train_minibatch_size: int,
dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
@@ -38,9 +43,18 @@ def launch_distributed(
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
+ num_generations: int = 8,
master_addr: str = "localhost",
master_port: int = 29500,
+ core_algo: str = "GRPO",
+ project_name: Optional[str] = None,
):
+
+ if core_algo not in ALGO_MAP:
+ raise NotImplementedError(f"{core_algo} is not supported yet.")
+ else:
+ core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
+
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@@ -65,10 +79,17 @@ def launch_distributed(
tokenizer_config=tokenizer_config,
microbatch_size=inference_microbatch_size,
backend=inference_backend,
+ num_generations=num_generations,
)
procs.append(producer)
+ generate_config_consumer = copy.deepcopy(generate_config)
+ generate_config_consumer.update(
+ dict(
+ backend=inference_backend,
+ )
+ )
for i in range(num_consumer_procs):
- consumer = SimpleConsumer.options(num_gpus=1).remote(
+ consumer = core_consumer.options(num_gpus=1).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
@@ -80,7 +101,15 @@ def launch_distributed(
batch_size=train_batch_size,
model_config=train_model_config,
plugin_config=plugin_config,
- microbatch_size=train_microbatch_size,
+ microbatch_size=train_minibatch_size,
+ generate_config=generate_config_consumer,
+ training_config={
+ "filter_range": [0.05, 9.0],
+ "lr": 1e-6,
+ "train_microbatch_size": train_microbatch_size,
+ },
+ num_generations=num_generations,
+ project_name=project_name,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
new file mode 100644
index 000000000..90ad09736
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -0,0 +1,45 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from coati.distributed.utils import masked_mean
+
+
+class PolicyLoss(nn.Module):
+ """
+ Policy Loss for PPO
+ """
+
+ def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None:
+ super().__init__()
+ self.clip_eps = clip_eps
+ self.skip_threshold = skip_threshold
+ self.beta = beta
+
+ def forward(
+ self,
+ log_probs: torch.Tensor,
+ old_log_probs: torch.Tensor,
+ advantages: torch.Tensor,
+ per_token_kl: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None,
+ loss_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ skip = False
+ if action_mask is None:
+ ratio = (log_probs - log_probs.detach()).exp()
+ else:
+ ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
+
+ surr1 = ratio * advantages
+ surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
+ loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
+
+ if action_mask is not None:
+ loss = masked_mean(loss, action_mask)
+ else:
+ loss = loss.mean(dim=1)
+ if loss_mask is not None:
+ loss = loss * loss_mask
+ loss = loss.mean()
+ return loss, skip, ratio.max()
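A minimal shape-check sketch of the new `PolicyLoss` on dummy tensors; values are arbitrary, and in training the GRPO consumer supplies real log probs, advantages, and masks:

```python
import torch
from coati.distributed.loss import PolicyLoss

loss_fn = PolicyLoss(clip_eps=0.2, beta=0.01)

B, T = 2, 5                                  # 2 responses, 5 action tokens each
log_probs = torch.randn(B, T, requires_grad=True)
old_log_probs = log_probs.detach().clone()   # accepted by the signature; the ratio is
                                             # built from log_probs and its detached copy
advantages = torch.randn(B, 1).expand(B, T)  # group-relative advantage, broadcast per token
per_token_kl = torch.zeros(B, T)
action_mask = torch.ones(B, T)

loss, skip, max_ratio = loss_fn(log_probs, old_log_probs, advantages, per_token_kl, action_mask)
loss.backward()                              # scalar loss; gradients flow into log_probs
```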
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 3e4a5277a..2c6a24a36 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -100,7 +100,11 @@ class BaseProducer:
if i >= num_valid_microbatches:
break
outputs = self.rollout(**batch)
+
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
+ outputs["temperature"] = torch.tensor(
+ [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
+ ).to(outputs["input_ids"].device)
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
@@ -113,10 +117,19 @@ class BaseProducer:
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
+
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
self.load_state_dict(state_dict)
+ del state_dict
+ torch.cuda.empty_cache()
+ # linear annealing for 1 episode, temperature from initial to 0.7
+ if episode <= 0:
+ ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+ self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+ "temperature"
+ ] + ratio * 0.7
@ray.remote
@@ -135,6 +148,7 @@ class SimpleProducer(BaseProducer):
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
+ num_generations: int = 8,
):
super().__init__(
producer_idx,
@@ -150,11 +164,15 @@ class SimpleProducer(BaseProducer):
microbatch_size,
backend,
)
- self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+ self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
- return self.model.generate(input_ids, attention_mask, **kwargs)
+ rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
+ if self.producer_idx == 1:
+ print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
+
+ return rollouts
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
new file mode 100644
index 000000000..53bc15e25
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -0,0 +1,61 @@
+import torch
+
+from .reward_utils import extract_solution, validate_response_structure
+
+
+def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
+ format_score = 1.0
+ acc_score = 9.0
+ tokenizer = kwargs["tokenizer"]
+ reward = torch.tensor(0.0)
+ format_reward = torch.tensor(0.0)
+ acc_reward = torch.tensor(0.0)
+ s, e = response_idx[0], response_idx[1]
+ if gt_answer is None:
+ return reward
+
+ decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+ gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
+ final_answer, processed_str = extract_solution(decoded_final_answer)
+
+ format_valid = validate_response_structure(processed_str, kwargs["tags"])
+
+ # Check format accuracy
+ if format_valid:
+ format_reward += format_score
+ reward += format_score
+
+ # Check answer accuracy
+ if (
+ final_answer is not None
+ and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
+ ):
+ acc_reward += acc_score
+ reward += acc_score
+
+ return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
+
+
+def gsm8k_reward_fn(input_ids, **kwargs):
+ gt_answer = kwargs["gt_answer"]
+ tokenizer = kwargs["tokenizer"]
+ s, e = kwargs["response_start"], kwargs["response_end"]
+ reward = torch.tensor(0.0).to(input_ids.device)
+ if gt_answer is None:
+ return reward
+ decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+ final_answer, processed_str = extract_solution(decoded_final_answer)
+ is_valid = True
+ try:
+ int(final_answer.strip())
+ except Exception:
+ is_valid = False
+
+ format_valid = validate_response_structure(processed_str, kwargs["tags"])
+ if not is_valid or not format_valid:
+ return reward
+ else:
+ reward += 1.0
+ if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
+ reward = reward + 9.0
+ return reward
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_utils.py b/applications/ColossalChat/coati/distributed/reward/reward_utils.py
new file mode 100644
index 000000000..c1e73d4b9
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/reward/reward_utils.py
@@ -0,0 +1,76 @@
+# Copyright Unakar
+# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Dict, Optional, Tuple
+
+
+def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
+ """Performs comprehensive validation of response structure.
+
+ Args:
+ processed_str: Processed response string from the model
+ tags: Dict of required tags with their expected number of occurrences
+
+ Returns:
+ Boolean indicating whether all formatting requirements are met
+ """
+ validation_passed = True
+ # Check required tags
+ if tags is None:
+ tags = {
+ "think_start": {"text": "<think>", "num_occur": 1},
+ "think_end": {"text": "</think>", "num_occur": 1},
+ "answer_start": {"text": "<answer>", "num_occur": 1},
+ "answer_end": {"text": "</answer>", "num_occur": 1},
+ }
+ positions = {}
+ for tag_name, tag_info in tags.items():
+ tag_str = tag_info["text"]
+ expected_count = tag_info["num_occur"]
+ count = processed_str.count(tag_str)
+ positions[tag_name] = pos = processed_str.find(tag_str)
+ if count != expected_count:
+ validation_passed = False
+ # Verify tag order
+ if (
+ positions["think_start"] > positions["think_end"]
+ or positions["think_end"] > positions["answer_start"]
+ or positions["answer_start"] > positions["answer_end"]
+ ):
+ validation_passed = False
+ if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]):
+ validation_passed = False
+ return validation_passed
+
+
+def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
+ """Extracts the final answer from the model's response string.
+
+ Args:
+ solution_str: Raw response string from the language model
+
+ Returns:
+ Tuple containing (extracted_answer, processed_string)
+ """
+
+ # Extract final answer using XML-style tags
+ answer_pattern = r"<answer>(.*?)</answer>"
+ matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
+
+ if not matches:
+ return None, solution_str
+
+ final_answer = matches[-1].group(1).strip()
+ return final_answer, solution_str
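A quick sanity check of the two helpers above on an illustrative, well-formed response string:

```python
from coati.distributed.reward.reward_utils import extract_solution, validate_response_structure

response = "<think> 2 + 3 = 5 </think> <answer> 5 </answer>"
final_answer, processed = extract_solution(response)
assert final_answer == "5"
# All four tags occur exactly once, in the right order, and the string ends with </answer>.
assert validate_response_structure(processed) is True
```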
diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
new file mode 100644
index 000000000..ba83f7787
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py
@@ -0,0 +1,43 @@
+"""
+Function-based reward verification module.
+"""
+
+from typing import Any, Dict, List
+
+import torch
+
+
+class VerifiableReward:
+ def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]):
+ self.reward_fns = reward_fns
+ self.kwargs = kwargs
+
+ def __call__(
+ self,
+ input_ids: torch.LongTensor,
+ gt_answer: List[torch.Tensor] = None,
+ response_idx: List[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Get batch size
+ bs = input_ids.size(0)
+ # Initialize reward
+ rewards = torch.zeros((bs, 3), device=input_ids.device)
+
+ # Loop through reward functions
+ for reward_fn in self.reward_fns:
+ # Apply the reward function to each sample in the batch and stack the results
+ reward_batch = torch.stack(
+ [
+ reward_fn(
+ input_ids[i],
+ gt_answer=gt_answer[i],
+ response_idx=response_idx[i],
+ **self.kwargs,
+ )
+ for i in range(bs)
+ ],
+ dim=0,
+ )
+
+ rewards += reward_batch
+ return rewards
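A minimal sketch of how `VerifiableReward` aggregates per-sample scores. The toy reward function below is hypothetical and stands in for `math_reward_fn`, which returns a `[reward, format_reward, acc_reward]` triple per sample:

```python
import torch
from coati.distributed.reward.verifiable_reward import VerifiableReward

def toy_reward_fn(input_ids, gt_answer=None, response_idx=None, **kwargs):
    # Pretend every sample earns the format reward but not the accuracy reward.
    return torch.tensor([1.0, 1.0, 0.0], device=input_ids.device)

rm = VerifiableReward(reward_fns=[toy_reward_fn], tokenizer=None, tags=None)
scores = rm(
    torch.zeros(4, 16, dtype=torch.long),  # dummy token ids for a batch of 4
    gt_answer=[None] * 4,
    response_idx=[None] * 4,
)
assert scores.shape == (4, 3)              # one [reward, format, acc] row per sample
```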
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 533a5ffb2..919e4434f 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -2,6 +2,8 @@ from typing import Any, Dict, List
import torch
+from colossalai.shardformer.layer.loss import dist_log_prob
+
def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
batches = []
@@ -64,3 +66,50 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
log_probs = torch.log_softmax(logits, dim=-1)
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return per_label_logps.squeeze(-1)
+
+
+def calc_action_log_probs(
+ logits: torch.Tensor,
+ sequences: torch.LongTensor,
+ num_actions: int,
+ shard_config,
+ vocab_size: int = None,
+) -> torch.Tensor:
+ """Calculate action log probs.
+
+ Args:
+ logits (torch.Tensor): Output tensor of Actor.forward.logits.
+ sequences (torch.LongTensor): Input sequences.
+ num_actions (int): Number of actions.
+ shard_config: Shardformer ShardConfig, used to compute distributed log probs under tensor parallelism.
+ vocab_size (int, optional): Full (unsharded) vocabulary size; inferred from the logits when not provided.
+
+
+ Returns:
+ torch.Tensor: Action log probs.
+ """
+ # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
+ # logits: torch.Tensor, # [B, S, Vocab_size]
+ log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
+ log_probs = log_probs.squeeze(-1)
+ return log_probs[:, -num_actions:]
+
+
+def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
+ """
+ Compute the masked mean of a tensor along a specified dimension.
+
+ Args:
+ tensor (torch.Tensor): The input tensor.
+ mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
+ dim (int, optional): The dimension along which to compute the mean. Default is 1.
+
+ Returns:
+ torch.Tensor: The masked mean tensor.
+
+ """
+ tensor = tensor * mask
+ tensor = tensor.sum(dim=dim)
+ mask_sum = mask.sum(dim=dim)
+ mean = tensor / (mask_sum + 1e-8)
+ return mean
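For reference, `masked_mean` averages only over unmasked positions; a tiny worked example with arbitrary values:

```python
import torch
from coati.distributed.utils import masked_mean

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
m = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(masked_mean(x, m))  # tensor([1.5000]); only the first two positions contribute
```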
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index a6f82b3be..6c43ccd19 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -10,54 +10,83 @@ if __name__ == "__main__":
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
- parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
- parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
- parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
- parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
- parser.add_argument("-b", "--backend", type=str, default="transformers")
+ parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+ parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+ parser.add_argument(
+ "-ibs",
+ "--inference-batch-size",
+ type=int,
+ default=64,
+ help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
+ )
+ parser.add_argument(
+ "-imbs",
+ "--inference-microbatch-size",
+ type=int,
+ default=8,
+ help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
+ )
+ parser.add_argument(
+ "-tbs",
+ "--train-batch-size",
+ type=int,
+ default=32,
+ help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
+ )
+ parser.add_argument(
+ "-tMbs",
+ "--train-minibatch-size",
+ type=int,
+ default=1,
+ help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Must satisfy tMbs * g >= tmbs",
+ )
+ parser.add_argument(
+ "-tmbs",
+ "--train-microbatch-size",
+ type=int,
+ default=2,
+ help="Effective batch size per dp group for forwarding and backwarding. Please select based on the available memory.",
+ )
+ parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
+ parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
args = parser.parse_args()
+ assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
+ assert (
+ args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
+ and args.train_microbatch_size > 0
+ ), "Train micro batch size must be greater than 0 and no larger than train mini batch size * num generations"
+
ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model)
- train_model_config = dict(path=args.model)
- generate_config = dict(
- top_k=50,
- top_p=0.8,
- )
+ train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+ generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
if args.backend == "transformers":
inference_model_config.update(
dict(
- attn_implementation="flash_attention_2",
+ use_flash_attention_2=True,
torch_dtype=torch.bfloat16,
)
)
- train_model_config.update(
- dict(
- attn_implementation="flash_attention_2",
- torch_dtype=torch.bfloat16,
- use_cache=False,
- )
- )
generate_config.update(
dict(
- max_length=512,
+ max_length=1024 + 512,
do_sample=True,
max_new_tokens=None,
early_stopping=False,
+ stop_strings=["</answer>"],
)
)
elif args.backend == "vllm":
- inference_model_config.update(
- dict(
- gpu_memory_utilization=0.6,
- )
- )
+ inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
generate_config.update(
dict(
- max_tokens=256,
+ max_tokens=2048,
ignore_eos=True,
+ include_stop_str_in_output=True,
+ stop=["</answer>"],
)
)
else:
@@ -77,18 +106,29 @@ if __name__ == "__main__":
num_producers=args.num_inferencer,
num_proc_per_producer=1,
num_consumer_procs=args.num_trainers,
- num_episodes=1,
+ num_episodes=10,
inference_batch_size=args.inference_batch_size,
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
+ train_minibatch_size=args.train_minibatch_size,
train_microbatch_size=args.train_microbatch_size,
- dataset_config={"path": args.dataset, "max_length": 256},
+ dataset_config={"path": args.dataset, "max_length": 300},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
+ num_generations=args.num_generations,
train_model_config=train_model_config,
- plugin_config={},
+ # plugin_config={}, # for zero
+ plugin_config={
+ "pp_size": 2,
+ "tp_size": 2,
+ "microbatch_size": args.train_microbatch_size // 2,
+ "zero_stage": 0,
+ "max_norm": 1.0,
+ }, # for pp
inference_backend=args.backend,
master_addr="localhost",
- master_port=29504,
+ master_port=29506,
+ core_algo=args.algo,
+ project_name=args.project,
)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 1684fd702..74349091b 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
)
# run with gradients accumulation
- if model.require_grad_sync == False or (
- isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+ if (
+ not torch.is_grad_enabled()
+ or model.require_grad_sync == False
+ or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
):
return outputs
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 0bd1b6092..a1b80bf56 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d, dist_cross_entropy
+from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import (
@@ -28,6 +28,8 @@ __all__ = [
"DropoutForReplicatedInput",
"cross_entropy_1d",
"dist_cross_entropy",
+ "dist_log_prob_1d",
+ "dist_log_prob",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 0e2241af9..a9bb76fc7 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -3,13 +3,21 @@ import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
+from torch.nn.functional import log_softmax
from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig
from .utils import is_share_sp_tp
-__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
+__all__ = [
+ "DistCrossEntropy",
+ "cross_entropy_1d",
+ "dist_cross_entropy",
+ "DistLogProb",
+ "dist_log_prob_1d",
+ "dist_log_prob",
+]
_IGNORE_IDX = -100
@@ -137,6 +145,98 @@ class DistCrossEntropy(Function):
return grad_logits, None, None, None, None, None, None
+class DistLogProb(Function):
+ r"""
+ Overwrite the forward and backward functions to calculate the log prob before gathering across tensor-parallel ranks.
+
+ Args:
+ Function (:class:`torch.autograd.Function`): default
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ vocab_logits: torch.Tensor,
+ target: torch.Tensor,
+ process_group: ProcessGroup,
+ vocab_size: int,
+ dtype=torch.float32,
+ ):
+
+ ##################
+ # Step1:Find the global maximum value of logits
+ ##################
+ logits_max = torch.max(vocab_logits, dim=-1)[0]
+ handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
+
+ ##################
+ # Step2: Find the local mask. The local mask will be used to select log_probs values in Step 4.
+ # For acceleration, we overlap Step 2 and Step 3
+ ##################
+ rank = dist.get_rank(group=process_group)
+ world_size = dist.get_world_size(group=process_group)
+ if vocab_size is None:
+ partition_vocab_size = vocab_logits.size()[-1]
+ global_vocab_size = partition_vocab_size * world_size
+ else:
+ global_vocab_size = vocab_size
+ partition_vocab_size = global_vocab_size // world_size
+ # down and up threshold for local logits
+ delta = (global_vocab_size + world_size - 1) // world_size
+ down_threshold = rank * delta
+ up_threshold = down_threshold + delta
+ if up_threshold > global_vocab_size:
+ up_threshold = global_vocab_size
+ # mask
+ mask = (target < down_threshold) | (target >= up_threshold)
+ masked_target = target.clone() - down_threshold
+ masked_target[mask] = 0
+ masked_target_1d = masked_target.view(-1).contiguous()
+ handle.wait()
+
+ ##################
+ # Step3:Calculate global summation exp logits
+ ##################
+ vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
+ exp_logits = torch.exp(vocab_logits)
+ sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits
+ dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
+
+ ##################
+ # Step4: Calculate the local log prob. We first compute log_softmax, then select log prob values via the local mask
+ ##################
+ log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax
+ log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
+ log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero
+ dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)
+
+ ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
+ ctx.dtype = dtype
+ return log_probs
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
+ ##################
+ # Step1: Find the global softmax value
+ ##################
+ softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)
+
+ ##################
+ # Step2:Update softmax value based on local target index
+ ##################
+ partition_vocab_size = softmax_logits.shape[-1]
+ softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
+ update = 1.0 - mask.view(-1).float().to(ctx.dtype)
+ softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update
+
+ ##################
+ # Step3:Calculate grad_output, which is the gradient of the loss function with respect to the output of logsoftmax
+ ##################
+ grad_logits = -softmax_logits.mul_(grad_output)
+ return grad_logits, None, None, None, None, None, None
+
+
def cross_entropy_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
@@ -149,6 +249,16 @@ def cross_entropy_1d(
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
+def dist_log_prob_1d(
+ vocab_logits: torch.Tensor,
+ labels: torch.Tensor,
+ process_group: ProcessGroup = None,
+ vocab_size: int = None,
+ dtype: torch.dtype = None,
+) -> torch.Tensor:
+ return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)
+
+
def dist_cross_entropy(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
@@ -243,3 +353,41 @@ def dist_cross_entropy(
loss, num_nonzero = loss[0], loss[1].detach()
loss = (loss / num_nonzero).squeeze()
return loss
+
+
+def dist_log_prob(
+ labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
+ logits: torch.Tensor, # [B, S, Vocab_size]
+ shard_config: ShardConfig,
+ vocab_size: int,
+ dtype: torch.dtype,
+ seq_dim: int = 1,
+) -> torch.Tensor:
+ """
+ Helper to compute log prob for most shardformer models supporting PP, TP.
+ """
+ # Split labels if not gather output
+ parallel_output = shard_config.parallel_output
+ is_tp = shard_config.enable_tensor_parallelism
+
+ # TODO:support sp
+ labels = labels[..., 1:]
+ logits = logits[..., :-1, :]
+ labels = labels.contiguous()
+ logits = logits.contiguous()
+ assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
+
+ # Flatten the tokens
+ if is_tp and parallel_output:
+ log_prob = dist_log_prob_1d(
+ logits,
+ labels,
+ process_group=shard_config.tensor_parallel_process_group,
+ vocab_size=vocab_size,
+ dtype=dtype,
+ )
+ else:
+ log_prob = log_softmax(logits, dim=-1)
+ log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
+
+ return log_prob
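In equation form, with the vocabulary sharded across tensor-parallel ranks r (shard V_r) and m the global maximum obtained via all-reduce, `DistLogProb` computes for each target token y:

```latex
\[
\log p(y) \;=\; \bigl(z_y - m\bigr) \;-\; \log \sum_{r} \sum_{v \in V_r} e^{\,z_v - m},
\qquad
m = \max_{r,\; v \in V_r} z_v ,
\]
```

where the z_y - m term is contributed only by the rank that owns y; all other ranks contribute zero before the final sum all-reduce.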
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 569fc4a45..27571309e 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
+ **kwargs,
):
r"""
Args:
@@ -832,7 +833,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
-
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 84d2b2fdb..0adcdfdbd 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -13,6 +13,7 @@ from colossalai.shardformer.layer import (
PaddingEmbedding,
RMSNorm,
VocabParallelEmbedding1D,
+ VocabParallelLMHead1D,
)
from ..modeling.qwen2 import (
@@ -429,8 +430,12 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
- target_module=Linear1D_Col,
- kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
+ target_module=VocabParallelLMHead1D,
+ kwargs=dict(
+ gather_output=not self.shard_config.parallel_output,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
@@ -446,7 +451,16 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
suffix="lm_head",
target_module=LinearWithGradAccum,
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
- )
+ ),
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=VocabParallelLMHead1D,
+ kwargs={
+ "gather_output": not self.shard_config.parallel_output,
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ "fp8_communication": self.shard_config.fp8_communication,
+ },
+ ),
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
)
diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py
new file mode 100644
index 000000000..05a6a5d47
--- /dev/null
+++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py
@@ -0,0 +1,52 @@
+import pytest
+import torch
+from coati.distributed.utils import log_probs_from_logits
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer import dist_log_prob_1d
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(
+ parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
+)
+
+
+def check_dist_log_prob(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
+
+ # prepare data
+ pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
+ labels = torch.randint(8, (2, 4)).cuda()
+
+ logprob = log_probs_from_logits(pred, labels)
+
+ pred.retain_grad()
+ logprob.mean().backward()
+
+ dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
+ dist_pred.requires_grad = True
+ dist_logprob = dist_log_prob_1d(dist_pred, labels)
+
+ dist_pred.retain_grad()
+ dist_logprob.squeeze(-1).mean().backward()
+
+ assert torch.allclose(
+ logprob, dist_logprob.squeeze(-1), atol=1e-5
+ ), f"dist logprob is not equal to original logprob\n{logprob}\n{dist_logprob.squeeze(-1)}"
+
+ pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
+ assert torch.allclose(
+ pred_grad_partial, dist_pred.grad
+ ), f"dist grad is not equal to original grad\n{pred.grad}\n{dist_pred.grad}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_log_prob():
+ spawn(check_dist_log_prob, 2)
+
+
+if __name__ == "__main__":
+ test_dist_log_prob()