Merge pull request #6250 from hpcaitech/grpo-latest-dev

[feat] Fix vLLM backend and log-prob computation; add reward filtering, temperature annealing, and LR decay
YeAnbang 2025-03-21 16:25:35 +08:00 committed by GitHub
commit 489f215ad9
8 changed files with 239 additions and 42 deletions

View File

@@ -57,6 +57,7 @@ class BaseConsumer:
         assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
         self.device = get_current_device()
+        self.lr_scheduler = None

     def setup(self) -> None:
         for i in range(self.num_producers):
@@ -121,6 +122,8 @@ class BaseConsumer:
                         pbar.set_postfix({"loss": loss})
                         i += 1
                 assert len(self.buffer) == 0
+                if self.lr_scheduler is not None:
+                    self.lr_scheduler.step()
                 if (step + 1) % self.save_interval == 0:
                     if self.rank == 0:
                         print(f"Start saving policy model at step {step + 1}.")
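The consumer now owns an optional lr_scheduler and advances it once per optimizer update, after the gradient-accumulation buffer is drained. A minimal sketch of that ordering in a generic training loop (the train_step helper and the HF-style .loss output are illustrative assumptions, not code from this repository):

def train_step(model, optimizer, lr_scheduler, batch):
    # One optimizer update; model(**batch).loss assumes an HF-style output object.
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Advance the optional LR schedule once per optimizer update, mirroring BaseConsumer.
    if lr_scheduler is not None:
        lr_scheduler.step()
    return loss.item()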

View File

@@ -1,8 +1,11 @@
+import json
+import os
 from contextlib import nullcontext
 from typing import Optional

 import ray
 import torch
+import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -12,6 +15,7 @@ from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
@@ -31,8 +35,10 @@ class GRPOConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
-        num_generations=4,
+        num_generations=8,
         use_wandb=True,
+        generate_config=None,
+        training_config={},
     ):
         super().__init__(
             num_producers,
@@ -52,7 +58,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -61,6 +67,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
+        self.generate_config = generate_config
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -69,6 +76,9 @@ class GRPOConsumer(BaseConsumer):
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
+        self.filter_range = training_config.get("filter_range", None)
+        if self.filter_range is not None:
+            assert len(self.filter_range) == 2, "Filter range should have 2 values."

         # Initialize verifiable reward.
         response_format_tags = {
@@ -84,11 +94,21 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)

+        self.lr_scheduler = CosineAnnealingWarmupLR(
+            optimizer=self.optimizer,
+            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+            warmup_steps=0,
+            eta_min=0.1 * training_config.get("lr", 1e-6),
+        )

     def setup(self):
         super().setup()
-        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+        )
         self.reference_model, *_ = self.booster.boost(self.reference_model)

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
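The scheduler registered above is ColossalAI's CosineAnnealingWarmupLR, set to decay from the base learning rate to 10% of it over min(num_episodes, 4) * num_update_per_episode steps with no warmup. A rough restatement of the schedule it is expected to follow (a sketch of the standard cosine-with-warmup formula, not the library's exact implementation):

import math

def cosine_decay_lr(step: int, total_steps: int, base_lr: float = 1e-6,
                    eta_min_ratio: float = 0.1, warmup_steps: int = 0) -> float:
    # Linear warmup to base_lr, then cosine decay down to eta_min = eta_min_ratio * base_lr.
    eta_min = eta_min_ratio * base_lr
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))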
@@ -113,7 +133,6 @@ class GRPOConsumer(BaseConsumer):
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
@@ -121,7 +140,10 @@
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )

             with torch.no_grad():
@@ -130,7 +152,10 @@
                     attention_mask=data["attention_mask"],
                 )["logits"]
             reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                reference_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )

             per_token_kl = (
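Dividing both policy and reference logits by self.generate_config["temperature"] makes the training-side log-probabilities consistent with the distribution the rollout backend actually sampled from. A self-contained sketch of the same computation for a non-sharded model (the helper name and shapes are illustrative; the repository's calc_action_log_probs additionally handles sharded logits):

import torch
import torch.nn.functional as F

def action_log_probs_from_logits(logits: torch.Tensor, input_ids: torch.Tensor,
                                 num_actions: int, temperature: float = 1.0) -> torch.Tensor:
    # logits: [B, S, V] from a causal LM; input_ids: [B, S]; the last num_actions tokens are the response.
    logits = logits[:, :-1, :] / temperature                  # position t predicts token t+1
    targets = input_ids[:, 1:]
    logp = F.log_softmax(logits, dim=-1)
    token_logp = logp.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)  # [B, S-1]
    return token_logp[:, -num_actions:]                       # log-probs of the generated tokens only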
@@ -149,21 +174,31 @@ class GRPOConsumer(BaseConsumer):
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
+            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
+            loss_mask = (
+                None
+                if self.filter_range is None
+                else torch.logical_and(
+                    reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+                ).repeat_interleave(self.num_generations, dim=0)
+            )

             # [batch_size x num_generations]
-            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
             advantages = (reward - reward_mean) / (reward_std + 1e-4)

+            # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                 per_token_kl,
                 action_mask,
+                loss_mask=loss_mask,
             )

             if not skip_update:
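Rewards are grouped per prompt into num_generations completions; the group mean and standard deviation normalize them into advantages, and prompts whose mean reward falls outside filter_range (roughly: every completion failed or every completion got full marks, so there is no learning signal) are masked out of the loss. A standalone sketch under the assumption that the flat reward tensor is ordered group by group:

import torch

def grpo_advantages(reward: torch.Tensor, num_generations: int,
                    filter_range=(0.05, 9.0), eps: float = 1e-4):
    # reward: [batch_size * num_generations]; completions of one prompt are contiguous.
    group = reward.view(-1, num_generations)                     # [B, G]
    mean = group.mean(dim=1)                                     # [B]
    std = group.std(dim=1)                                       # [B]
    # Keep only prompts whose mean reward is informative (not all-min / all-max).
    keep = (mean > filter_range[0]) & (mean < filter_range[1])   # [B]
    loss_mask = keep.repeat_interleave(num_generations)          # [B * G]
    advantages = (reward - mean.repeat_interleave(num_generations)) / (
        std.repeat_interleave(num_generations) + eps
    )
    return advantages, loss_mask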
@@ -207,13 +242,15 @@ class GRPOConsumer(BaseConsumer):
                 )
                 self.wandb_run.log(
                     {
+                        "metrics/reward": self.accum_reward.item() / self.accum_count,
+                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                         "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/response_length": self.accum_response_length.item() / self.accum_count,
+                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                 )
                 self.accum_loss.zero_()
@@ -232,3 +269,125 @@ class GRPOConsumer(BaseConsumer):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=4,
+        use_wandb=True,
+        log_dir="./results",
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.accum_count = torch.zeros(1, device=self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.log_dir = log_dir
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        else:
+            os.system(f"rm -rf {self.log_dir}/*")
+
+    def setup(self):
+        super().setup()
+        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        rank = dist.get_rank()
+        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+        kwargs["input_ids"].size(0)
+        reward_group = self.reward_model(
+            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+        )
+        reward = [value[0].item() for value in reward_group]
+        format_reward = [value[1].item() for value in reward_group]
+        acc_reward = [value[2].item() for value in reward_group]
+        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+            for i in range(len(response)):
+                f.write(
+                    json.dumps(
+                        {
+                            "response": response[i],
+                            "reward": reward[i],
+                            "format_reward": format_reward[i],
+                            "acc_reward": acc_reward[i],
+                            "response_length": response_length[i],
+                        },
+                        ensure_ascii=False,
+                    )
+                    + "\n"
+                )
+
+        self.accum_reward += sum(reward)
+        self.accum_format_reward += sum(format_reward)
+        self.accum_acc_reward += sum(acc_reward)
+        self.accum_response_length += sum(response_length)
+        self.accum_count += len(reward)
+
+        # print results
+        total_count = all_reduce_mean(self.accum_count, self.plugin)
+        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+        if rank == 0:
+            print(
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+            )
+        return None
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
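The evaluation consumer accumulates per-rank sums and derives global means by all-reducing both the metric sums and the sample count, so the final ratio equals the true global average. A minimal sketch of that aggregation with raw torch.distributed calls (assuming all_reduce_mean behaves like a sum all-reduce divided by the world size, in which case the factor cancels):

import torch
import torch.distributed as dist

def global_mean(metric_sum: torch.Tensor, count: torch.Tensor) -> torch.Tensor:
    # Sum both numerator and denominator across ranks, then divide once:
    # this equals dividing two all-reduced means, since 1/world_size cancels.
    metric_sum = metric_sum.clone()
    count = count.clone()
    dist.all_reduce(metric_sum, op=dist.ReduceOp.SUM)
    dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return metric_sum / count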

View File

@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)

-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
         path = model_config.pop("path")
@@ -61,7 +67,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
-        self.num_generations = 8
+        self.num_generations = num_generations

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@
 class SGLangInferenceBackend(BaseInferenceBackend):
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if sgl is None:
             raise ImportError("sglang is not installed")
         path = model_config.pop("path")
@@ -175,27 +187,38 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=8,
     )

-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(path, **model_config)
+        self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
+        generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = num_generations

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
+        micro_batch_input_ids = input_ids.tolist()
+        micro_batch_input_ids_no_padding = [
+            micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+        ]
         outputs = self.llm.generate(
-            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
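The vLLM backend previously passed left-padded input_ids straight to LLM.generate, so pad tokens were treated as part of the prompt; the fix strips the leading padding per row before building prompt_token_ids. The same idea as a standalone helper (hypothetical function name):

import torch

def strip_left_padding(input_ids: torch.Tensor, pad_token_id: int) -> list:
    # Drop leading pad tokens from each left-padded row before handing raw token ids
    # to an inference engine that expects unpadded prompts.
    first_real = (input_ids != pad_token_id).int().argmax(dim=1).tolist()  # first non-pad index per row
    rows = input_ids.tolist()
    return [rows[i][first_real[i]:] for i in range(input_ids.size(0))]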

View File

@@ -1,15 +1,13 @@
+import copy
 from typing import Any, Dict, Optional

 import ray

 from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer
+from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
 from .producer import SimpleProducer

-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}


 def get_jsonl_size_fast(path: str) -> int:
@@ -44,6 +42,7 @@ def launch_distributed(
     plugin_config: Dict[str, Any],
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
+    num_generations: int = 8,
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
@@ -78,8 +77,15 @@ def launch_distributed(
             tokenizer_config=tokenizer_config,
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
+            num_generations=num_generations,
         )
         procs.append(producer)
+    generate_config_consumer = copy.deepcopy(generate_config)
+    generate_config_consumer.update(
+        dict(
+            backend=inference_backend,
+        )
+    )
     for i in range(num_consumer_procs):
         consumer = core_consumer.options(num_gpus=1).remote(
             num_producers=num_producers,
@@ -94,6 +100,9 @@ def launch_distributed(
             model_config=train_model_config,
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
+            generate_config=generate_config_consumer,
+            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            num_generations=num_generations,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
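The launcher threads the new knobs through both process groups: producers receive num_generations, while each consumer gets a deep copy of the rollout generate_config (tagged with the backend name, so the consumer can read the sampling temperature) plus a training_config holding the reward filter range and learning rate. A small self-contained sketch of that wiring with placeholder consumer classes (values mirror the diff; everything else is illustrative):

import copy

class SimpleConsumer: ...
class GRPOConsumer: ...
class GRPOEvalConsumer: ...

ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
core_consumer = ALGO_MAP["GRPO"]  # consumer class selected by algorithm name

inference_backend = "vllm"
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
# Consumers get their own copy of the sampling config plus the backend name.
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(dict(backend=inference_backend))
training_config = {"filter_range": [0.05, 9.0], "lr": 1e-6}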

View File

@@ -23,6 +23,7 @@ class PolicyLoss(nn.Module):
         advantages: torch.Tensor,
         per_token_kl: torch.Tensor,
         action_mask: Optional[torch.Tensor] = None,
+        loss_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
@@ -38,5 +39,7 @@ class PolicyLoss(nn.Module):
             loss = masked_mean(loss, action_mask)
         else:
             loss = loss.mean(dim=1)
+        if loss_mask is not None:
+            loss = loss * loss_mask
         loss = loss.mean()
         return loss, skip, ratio.max()
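PolicyLoss gains an optional loss_mask: after the per-token loss is reduced to one value per sequence, filtered sequences are multiplied by zero before the final mean. A standalone sketch of that last reduction step (note the mean still divides by the full batch size, matching the code above):

from typing import Optional
import torch

def apply_loss_mask(per_sequence_loss: torch.Tensor, loss_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # per_sequence_loss: [B * num_generations]; loss_mask: bool tensor of the same shape, or None.
    if loss_mask is not None:
        per_sequence_loss = per_sequence_loss * loss_mask  # filtered sequences contribute zero
    return per_sequence_loss.mean()                        # mean still divides by the full batch size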

View File

@@ -101,6 +101,9 @@ class BaseProducer:
                     break
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
+                outputs["temperature"] = torch.tensor(
+                    [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
+                ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
@@ -117,6 +120,12 @@ class BaseProducer:
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
                     self.load_state_dict(state_dict)
+                    # linear annealing for 1 episode, temperature from initial to 0.7
+                    if episode <= 0:
+                        ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+                        self.model.generate_config.temperature = (
+                            ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
+                        )


 @ray.remote
@@ -135,6 +144,7 @@ class SimpleProducer(BaseProducer):
         tokenizer_config=None,
         microbatch_size=1,
         backend="transformers",
+        num_generations: int = 8,
     ):
         super().__init__(
             producer_idx,
@@ -150,7 +160,7 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
         )
-        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)

     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
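During the first episode the producer recomputes its sampling temperature after every model sync by linearly blending the configured temperature with 0.7, using the batch index i as the interpolation weight; it also ships the current temperature with each rollout batch so the consumer can rescale logits consistently. The same arithmetic as a standalone function (hypothetical name):

def annealed_temperature(i: int, num_batches: int, start_temp: float, end_temp: float = 0.7) -> float:
    # ratio goes from 0 at the first batch of the episode to ~1 at the last one,
    # so the returned value moves linearly between end_temp and start_temp.
    ratio = 1 - (num_batches - i) / num_batches
    return ratio * start_temp + (1 - ratio) * end_temp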

View File

@@ -15,18 +15,14 @@ if __name__ == "__main__":
     parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()

     ray.init(address="local", namespace="ray-example")
     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model)
-    generate_config = dict(
-        top_k=50,
-        top_p=0.9,
-        temperature=1.0,
-    )
+    generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)

     if args.backend == "transformers":
         inference_model_config.update(
@@ -52,19 +48,13 @@ if __name__ == "__main__":
             )
         )
     elif args.backend == "vllm":
-        inference_model_config.update(
-            dict(
-                gpu_memory_utilization=0.7,
-            )
-        )
+        inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
         generate_config.update(
             dict(
                 max_tokens=2048,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
-                temperature=0.7,
-                top_p=0.95,
             )
         )
     else:
@@ -97,6 +87,6 @@ if __name__ == "__main__":
         plugin_config={},
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29504,
+        master_port=29503,
         core_algo=args.algo,
     )
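With the script changes above, the shared sampling defaults are now top_k=50, top_p=0.75, temperature=0.9, and the vLLM branch no longer overrides temperature or top_p, so the producer-side annealing is what varies the temperature at runtime. The effective vLLM sampling settings after the in-place update (values copied from the diff):

# Shared sampling defaults from the updated script.
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
# vLLM branch: only length/stop handling is added; temperature and top_p stay as above,
# leaving the first-episode annealing in producer.py in control of the temperature.
generate_config.update(
    dict(
        max_tokens=2048,
        ignore_eos=True,
        include_stop_str_in_output=True,
        stop=["</answer>"],
    )
)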

View File

@@ -387,7 +387,7 @@ def dist_log_prob(
             dtype=dtype,
         )
     else:
-        log_prob = log_softmax(logits)
+        log_prob = log_softmax(logits, dim=-1)
         log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

     return log_prob
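The dist_log_prob fix passes dim=-1 explicitly so the softmax normalizes over the vocabulary axis; relying on the implicit dimension is deprecated in PyTorch and can silently normalize over the wrong axis for 3-D logits. A quick check of the intended behavior (assuming log_softmax here has torch.nn.functional.log_softmax semantics):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 32000)            # [batch, seq_len, vocab_size]
log_prob = F.log_softmax(logits, dim=-1)     # normalize over the vocabulary axis
# Each position's probabilities now sum to 1, which the implicit-dim call does not guarantee.
assert torch.allclose(log_prob.exp().sum(dim=-1), torch.ones(2, 5))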