mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-31 03:15:40 +00:00

add ppo

This commit is contained in:
parent eb6337f07f
commit 6a6634b6e8

.gitignore (vendored), 1 addition

@@ -162,4 +162,5 @@ coverage.xml
 # log, test files - ColossalChat
 applications/ColossalChat/logs
+applications/ColossalChat/wandb
 applications/ColossalChat/tests/logs

@@ -14,6 +14,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 
 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch

@@ -76,6 +77,10 @@ class BaseConsumer:
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
+        if hasattr(self, "critic_model"):
+            plugin_config.update({"custom_policy": get_autopolicy(self.critic_model.model)})
+            self.critic_plugin = HybridParallelPlugin(**plugin_config)
+            self.critic_booster = Booster(plugin=self.critic_plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
 

@@ -60,6 +60,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
+        self.generate_config["tokenizer"] = tokenizer
         self.tokenizer = tokenizer
 
     @torch.no_grad()
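
Passing the tokenizer inside generate_config pairs with the stop_strings=["</answer>"] option this commit adds to the transformers generate_config in the run script: in recent Hugging Face transformers, generate() needs a tokenizer in order to match stop strings. A minimal sketch of that pairing, with a placeholder model id:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-7B"  # placeholder, any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("1 + 1 =", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=32,
    stop_strings=["</answer>"],  # generation halts once this string is produced
    tokenizer=tokenizer,         # transformers requires the tokenizer when stop_strings is set
)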

@@ -76,21 +77,26 @@ class TransformersInferenceBackend(BaseInferenceBackend):
             action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
         action_log_probs = torch.cat(action_log_probs, dim=1)
         # get action mask
+        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
         action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
         if self.tokenizer.eos_token_id is not None:
             for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
                 action_mask[indices[0], indices[1] + 1 :] = 0
+        response_idx[:, 0] = input_len
+        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
+
         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
             attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)
 
         attention_mask = torch.cat((attention_mask, action_mask), dim=1)
 
         data = {
             "input_ids": out.sequences,
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx
         }
         return data
 
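
As a toy illustration of what the new response_idx tensor encodes (the start and end position of each generated response inside the padded sequence), here is a small standalone sketch; the eos handling mirrors the hunk above, but the token ids, lengths, and eos id are invented:

import torch

input_len = 4          # prompt length (invented)
eos_token_id = 2       # invented eos id
new_token_ids = torch.tensor([[5, 6, 2, 0, 0],
                              [7, 8, 9, 2, 0]])  # 2 sequences, up to 5 generated tokens

action_mask = torch.ones_like(new_token_ids)
for indices in torch.nonzero(new_token_ids == eos_token_id):
    action_mask[indices[0], indices[1] + 1 :] = 0  # zero out everything after the first eos

response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int)
response_idx[:, 0] = input_len                               # response starts right after the prompt
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1  # index of the last kept response token

# action_mask  -> [[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]
# response_idx -> [[4, 6], [4, 7]]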

@@ -154,7 +160,6 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=4,
     )
 
     def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):

@@ -167,7 +172,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = generate_config["n"]
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
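
With n=4 removed from FORCE_GENERATE_CONFIG, the number of samples per prompt now has to come from the caller-supplied generate_config, which therefore must contain an "n" key (presumably why the vLLM generate_config in the run script below gains n=1). A small sketch of that assumption with invented values:

from vllm import SamplingParams

generate_config = dict(temperature=0.5, top_p=0.95, n=1)  # caller config; "n" must be present
generate_config.update(dict(logprobs=0))                  # what remains of FORCE_GENERATE_CONFIG
num_generations = generate_config["n"]                    # KeyError here if the caller omitted "n"
sampling_params = SamplingParams(**generate_config)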

@@ -4,11 +4,13 @@ import ray
 
 from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
+from .ppo_consumer import PPOConsumer
 from .producer import SimpleProducer
 
 ALGO_MAP = {
     "Simple": SimpleConsumer,
     "GRPO": GRPOConsumer,
+    "PPO": PPOConsumer
 }
 
 
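
For orientation, the string keys of ALGO_MAP are what the run script's --algo flag selects. The sketch below shows the kind of lookup the launcher presumably performs; it is illustrative only, assumes the imports and mapping from the hunk above, and is not the repo's actual launch code:

def get_consumer_cls(algo: str):
    # illustrative dispatch over the ALGO_MAP defined above
    if algo not in ALGO_MAP:
        raise NotImplementedError(f"Unsupported algorithm: {algo}, supported: {list(ALGO_MAP)}")
    return ALGO_MAP[algo]  # "PPO" now resolves to the new PPOConsumer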

@@ -45,3 +45,31 @@ class PolicyLoss(nn.Module):
         loss = loss.mean(dim=1)
         loss = loss.mean()
         return loss, skip, ratio_.max()
+
+
+class ValueLoss(nn.Module):
+    """
+    Value Loss for PPO
+    """
+
+    def __init__(self, clip_eps: float = 0.2) -> None:
+        super().__init__()
+        self.clip_eps = clip_eps
+
+    def forward(
+        self,
+        values: torch.Tensor,
+        old_values: torch.Tensor,
+        advantage: torch.Tensor,
+        action_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        returns = advantage + old_values
+        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
+        surr1 = (values_clipped - returns) ** 2
+        surr2 = (values - returns) ** 2
+        if action_mask is not None:
+            # loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
+            loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
+        else:
+            loss = torch.mean(torch.max(surr1, surr2))
+        return 0.5 * loss
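
The new ValueLoss is the standard clipped value objective: the critic pays the larger of the clipped and unclipped squared errors, which keeps the value head from drifting too far from its previous estimate in a single update. A minimal numeric sketch of what forward computes, with invented single-element tensors and no action mask:

import torch

clip_eps = 0.2
values = torch.tensor([1.5])      # current critic prediction
old_values = torch.tensor([1.0])  # prediction before the update
advantage = torch.tensor([0.3])

returns = advantage + old_values                                                # 1.3
values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)  # 1.2
surr1 = (values_clipped - returns) ** 2                                         # 0.01
surr2 = (values - returns) ** 2                                                 # 0.04
loss = 0.5 * torch.max(surr1, surr2).mean()                                     # 0.02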

applications/ColossalChat/coati/distributed/ppo_consumer.py (new file, 262 lines)

@@ -0,0 +1,262 @@
+from contextlib import nullcontext
+from typing import Optional
+
+import ray
+import torch
+import wandb
+from coati.distributed.consumer import BaseConsumer
+from coati.distributed.loss import PolicyLoss, ValueLoss
+from coati.distributed.reward.reward_fn import math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
+from coati.trainer.utils import all_reduce_mean
+from coati.models import Critic, disable_dropout
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+@ray.remote
+class PPOConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=1,
+        gamma: float = 1.0,
+        lam: float = 0.95,
+        kl_coef: float = 0.05,
+        use_wandb=True,
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        self.gamma = gamma
+        self.lam = lam
+        self.kl_coef = kl_coef
+
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.policy_model.gradient_checkpointing_enable()
+        self.critic_model = Critic(path, **model_config)
+        self.critic_model.model.gradient_checkpointing_enable()
+        self.critic_model.train()
+
+        # Disable dropout
+        disable_dropout(self.policy_model)
+        disable_dropout(self.critic_model)
+
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)
+
+        self.accum_loss = torch.zeros(1, device=self.device)
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_advantage = torch.zeros(1, device=self.device)
+        self.accum_critic_loss = torch.zeros(1, device=self.device)
+        self.accum_count = 0
+
+        # Reference model is initialized from policy model.
+        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.reference_model.eval()
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.policy_loss_fn = PolicyLoss()
+        self.critic_loss_fn = ValueLoss()
+        self.global_step = 0
+        if use_wandb and self.rank == 0:
+            self.wandb_run = wandb.init(project="PPO-Test", sync_tensorboard=True)
+
+    def setup(self):
+        super().setup()
+        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.critic_model, self.critic_optimizer, *_ = self.critic_booster.boost(
+            self.critic_model, self.critic_optimizer
+        )
+        self.reference_model, *_ = self.booster.boost(self.reference_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        """
+        Step data from policy model:
+        [{
+            "input_ids": torch.Tensor,
+            "attention_mask": torch.Tensor,
+            "action_mask": torch.Tensor,
+            "action_log_probs": torch.Tensor,
+        },
+        ...]
+        Format:
+        [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
+        """
+
+        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        action_mask = data["action_mask"]
+        num_action = action_mask.shape[1]
+        old_action_log_probs = data["action_log_probs"].detach()
+
+        need_update = (step_idx + 1) % self.num_microbatches == 0
+
+        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+        with ctx:
+            policy_model_logits = self.policy_model(
+                input_ids=data["input_ids"],
+                attention_mask=data["attention_mask"],
+            )["logits"]
+            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
+
+            with torch.no_grad():
+                reference_model_logits = self.reference_model(
+                    input_ids=data["input_ids"],
+                    attention_mask=data["attention_mask"],
+                )["logits"]
+                reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
+
+            value = self.critic_model(
+                input_ids=data["input_ids"],
+                attention_mask=data["attention_mask"],
+            )
+            value = value[:, -num_action - 1 : -1] * action_mask
+
+            r = self.reward_model(
+                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+            )
+            reward, kl = compute_reward_ppo(
+                r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
+            )
+
+            # Calculate advantages
+            # reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46
+            lastgaelam = 0
+            advantage_reversed = []
+            for t in reversed(range(num_action)):
+                nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0
+                delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
+                lastgaelam = delta + self.gamma * self.lam * lastgaelam
+                advantage_reversed.append(lastgaelam)
+            advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
+            advantage = advantage.detach()
+
+            # KL divergence for logging
+            per_token_kl = (
+                torch.exp(reference_action_log_probs - action_log_probs)
+                - (reference_action_log_probs - action_log_probs)
+                - 1
+            )
+            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
+
+            # Calculate Loss
+            loss, skip_update, _ = self.policy_loss_fn(
+                action_log_probs,
+                old_action_log_probs,
+                advantage,
+                0,  # kl is already included in the advantage
+                action_mask,
+            )
+
+            # Critic Loss
+            # Hack: use the current value to approximate the old value, should be old value mathematically
+            critic_loss = self.critic_loss_fn(
+                value,
+                value.detach(),
+                advantage,
+                action_mask=action_mask,
+            )
+
+            if not skip_update:
+                self.booster.backward(loss, self.optimizer)
+                self.critic_booster.backward(critic_loss, self.critic_optimizer)
+            loss = all_reduce_mean(loss, self.plugin)
+            critic_loss = all_reduce_mean(critic_loss, self.plugin)
+            r_mean = all_reduce_mean(r.mean(), self.plugin)
+            kl = all_reduce_mean(kl.mean(), self.plugin)
+            advantage = all_reduce_mean(advantage.mean(), self.plugin)
+            self.accum_loss.add_(loss.data)
+            self.accum_critic_loss.add_(critic_loss.data)
+            self.accum_advantage.add_(advantage.data)
+            self.accum_reward.add_(r_mean.data)
+            self.accum_kl.add_(kl.data)
+            self.accum_count += 1
+            if self.rank == 0:
+                print(f"input_ids: {data['input_ids'].shape}, reward: {r_mean.item()}")
+        if need_update:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+            self.critic_optimizer.step()
+            self.critic_optimizer.zero_grad()
+            loss_scalar = self.accum_loss.item()
+            if self.rank == 0:
+                print(
+                    "Loss:",
+                    self.accum_loss.item() / self.accum_count,
+                    "Reward:",
+                    self.accum_reward.item() / self.accum_count,
+                    "KL:",
+                    self.accum_kl.item() / self.accum_count,
+                )
+                if self.global_step % 3 == 0:
+                    for i in range(min(3, data["input_ids"].shape[0])):
+                        response_decoded_for_logging = self.tokenizer.decode(
+                            data["input_ids"][i], skip_special_tokens=True
+                        )
+                        response_reward_for_logging = r[i]
+                        print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
+                self.wandb_run.log(
+                    {
+                        "train/loss": self.accum_loss.item() / self.accum_count,
+                        "train/reward": self.accum_reward.item() / self.accum_count,
+                        "train/kl": self.accum_kl.item() / self.accum_count,
+                        "train/critic_loss": self.accum_critic_loss.item() / self.accum_count,
+                        "train/advantage": self.accum_advantage.item() / self.accum_count,
+                    }
+                )
+            self.accum_loss.zero_()
+            self.accum_reward.zero_()
+            self.accum_kl.zero_()
+            self.accum_advantage.zero_()
+            self.accum_critic_loss.zero_()
+            self.accum_count = 0
+            self.global_step += 1
+            return loss_scalar
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
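
The advantage loop in step() is plain Generalized Advantage Estimation, following the TRL reference linked in the code, and it runs on a reward stream that already carries the per-token KL penalty from compute_reward_ppo (which is why the policy loss is called with a KL coefficient of 0). A self-contained sketch of the recursion delta_t = r_t + gamma * V_{t+1} - V_t, A_t = delta_t + gamma * lam * A_{t+1}, with invented numbers:

import torch

gamma, lam = 1.0, 0.95  # PPOConsumer defaults above

# one response of length 4: per-token (KL-shaped) rewards and critic values, numbers invented
reward = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
value = torch.tensor([[0.2, 0.3, 0.4, 0.5]])
num_action = reward.size(1)

lastgaelam = 0.0
advantage_reversed = []
for t in reversed(range(num_action)):
    nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0
    delta = reward[:, t] + gamma * nextvalues - value[:, t]  # TD residual delta_t
    lastgaelam = delta + gamma * lam * lastgaelam            # A_t = delta_t + gamma*lam*A_{t+1}
    advantage_reversed.append(lastgaelam)
advantage = torch.stack(advantage_reversed[::-1], dim=1)
# advantage -> [[0.7139, 0.6463, 0.5750, 0.5000]]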

@@ -154,6 +154,11 @@ class SimpleProducer(BaseProducer):
 
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
+        if self.backend_cls.__name__ == "TransformersInferenceBackend":
+            gt_answer = kwargs.pop("gt_answer")
+            out = self.model.generate(input_ids, attention_mask, **kwargs)
+            out["gt_answer"] = gt_answer.to(out["input_ids"].device)
+            return out
         return self.model.generate(input_ids, attention_mask, **kwargs)
 
     def load_state_dict(self, state_dict):

@@ -19,8 +19,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         return reward
     else:
         reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 2.0
+        # if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
+        #     reward = reward + 2.0
         return reward
 
 

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
@@ -99,3 +99,30 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     mask_sum = mask.sum(dim=dim)
     mean = tensor / (mask_sum + 1e-8)
     return mean
+
+
+def compute_reward_ppo(
+    r: Union[torch.Tensor, float],
+    kl_coef: float,
+    log_probs: torch.Tensor,
+    log_probs_base: torch.Tensor,
+    action_mask: Optional[torch.Tensor] = None,
+    reward_eps=5,
+) -> torch.Tensor:
+    """
+    Args:
+        log_probs: [batch_size, response_length]
+        log_probs_base: [batch_size, response_length]
+        action_mask: [batch_size, response_length]
+        r: float
+    Returns:
+        reward: [batch_size, response_length]
+    """
+    log_ratio = log_probs - log_probs_base  # address numerical instability issue
+    kl = -kl_coef * log_ratio * action_mask
+    reward = kl
+    r_clip = torch.clamp(r, -reward_eps, reward_eps)
+    for i in range(action_mask.size(0)):
+        assert action_mask[i].sum() > 0
+        reward[i, : action_mask[i].sum()] += r_clip[i]
+        reward[i, action_mask[i].sum() :] *= 0
+    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
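
compute_reward_ppo folds the KL penalty into the per-token reward stream: every response token receives -kl_coef * (log_prob - log_prob_base), and the clipped scalar reward is then added to every unmasked response position (all of them, not only the last token); the second return value is a per-token KL estimate. A toy run with invented numbers:

import torch

log_probs = torch.tensor([[-1.0, -1.2, -0.8, -0.5]])       # policy log-probs (invented)
log_probs_base = torch.tensor([[-1.1, -1.0, -0.8, -0.5]])  # reference log-probs (invented)
action_mask = torch.tensor([[1, 1, 1, 0]])                 # 3 real response tokens, 1 padded
r = torch.tensor([0.5])                                    # scalar sequence reward
kl_coef, reward_eps = 0.05, 5

log_ratio = log_probs - log_probs_base            # [[ 0.1, -0.2, 0.0, 0.0]]
reward = -kl_coef * log_ratio * action_mask       # [[-0.005, 0.01, 0.0, 0.0]]
r_clip = torch.clamp(r, -reward_eps, reward_eps)  # 0.5, unchanged since |r| < 5
n = int(action_mask[0].sum())                     # 3
reward[0, :n] += r_clip[0]                        # [[0.495, 0.51, 0.5, 0.0]]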

@@ -7,6 +7,7 @@ from coati.distributed.launch import launch_distributed
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
+    parser.add_argument("-rm", "--reward_model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)

@@ -15,7 +16,7 @@ if __name__ == "__main__":
     parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GPRO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GPRO", "PPO"])
     args = parser.parse_args()
 
     ray.init(address="local", namespace="ray-example")

@@ -27,26 +28,27 @@ if __name__ == "__main__":
         top_p=0.8,
     )
 
     if args.backend == "transformers":
         inference_model_config.update(
             dict(
-                attn_implementation="flash_attention_2",
+                use_flash_attention_2=True,
                 torch_dtype=torch.bfloat16,
             )
         )
         train_model_config.update(
             dict(
-                attn_implementation="flash_attention_2",
+                use_flash_attention_2=True,
                 torch_dtype=torch.bfloat16,
                 use_cache=False,
             )
         )
         generate_config.update(
             dict(
-                max_length=512,
+                max_length=768,
                 do_sample=True,
                 max_new_tokens=None,
                 early_stopping=False,
+                stop_strings=["</answer>"]
             )
         )
     elif args.backend == "vllm":

@@ -57,12 +59,13 @@ if __name__ == "__main__":
         )
         generate_config.update(
             dict(
-                max_tokens=2048,
+                max_tokens=512,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
-                temperature=0.7,
+                temperature=0.5,
                 top_p=0.95,
+                n=1,
             )
         )
     else: