YeAnbang 2025-03-19 17:07:20 +08:00
parent 7795d4c50d
commit 7ee4452f8c
5 changed files with 172 additions and 24 deletions

View File

@@ -1,8 +1,11 @@
+import json
+import os
from contextlib import nullcontext
from typing import Optional

import ray
import torch
+import torch.distributed as dist
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
@@ -33,6 +36,8 @@ class GRPOConsumer(BaseConsumer):
        microbatch_size=1,
        num_generations=4,
        use_wandb=True,
+        generate_config=None,
+        filter_range=None,
    ):
        super().__init__(
            num_producers,
@@ -69,6 +74,9 @@ class GRPOConsumer(BaseConsumer):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pad_token_id = self.tokenizer.pad_token_id
        self.num_generations = num_generations
+        self.generate_config = generate_config
+        self.filter_range = filter_range
+        if self.filter_range is not None:
+            assert len(self.filter_range) == 2, "Filter range should have 2 values."

        # Initialize verifiable reward.
        response_format_tags = {
@@ -84,7 +92,11 @@ class GRPOConsumer(BaseConsumer):
        self.policy_loss_fn = PolicyLoss()
        self.global_step = 0
        if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+            if "repetition_penalty" in generate_config:
+                name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_rep_penalty_{generate_config['repetition_penalty']:.01f}"
+            else:
+                name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}"
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)

    def setup(self):
        super().setup()
@@ -121,7 +133,7 @@
                attention_mask=data["attention_mask"],
            )["logits"]
            action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / self.generate_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
            )

            with torch.no_grad():
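Dividing the logits by the sampling temperature before computing per-token log-probabilities (done here for the policy model and in the next hunk for the reference model) keeps the trainer-side probabilities consistent with the distribution the inference backend actually sampled from. A minimal sketch of the idea, independent of the coati helpers (calc_action_log_probs is assumed to reduce to a log-softmax plus gather):

import torch
import torch.nn.functional as F

def action_log_probs(logits: torch.Tensor, actions: torch.Tensor, temperature: float) -> torch.Tensor:
    # logits: [B, T, vocab]; actions: [B, T] token ids that were sampled at `temperature`.
    # Scaling by 1/temperature reproduces the softmax the sampler used, so
    # log pi(a_t | s_t) matches between generation and training.
    logp = F.log_softmax(logits / temperature, dim=-1)
    return logp.gather(dim=-1, index=actions.unsqueeze(-1)).squeeze(-1)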
@@ -130,7 +142,7 @@
                    attention_mask=data["attention_mask"],
                )["logits"]
            reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                reference_model_logits / self.generate_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
            )

            per_token_kl = (
@@ -149,7 +161,14 @@
            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

            # [batch_size, num_generations]
+            # filter out rewards that are too high (all samples get full score) or too low (all samples get 0 score)
+            loss_mask = (
+                None
+                if self.filter_range is None
+                else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
+            )
            group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)

            # [batch_size x num_generations]
            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
@@ -164,6 +183,7 @@
                advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                per_token_kl,
                action_mask,
+                loss_mask=loss_mask,
            )

            if not skip_update:
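The filtering and the group-relative advantage introduced in the hunks above can be read in isolation as the following sketch (the std normalization and its epsilon are not visible in this diff and are assumptions; the commit's own step() is authoritative):

import torch

def grpo_advantages(reward: torch.Tensor, num_generations: int, filter_range=None):
    # reward: [batch_size * num_generations], one scalar per sampled response.
    group = reward.view(-1, num_generations)
    mean = group.mean(dim=1).repeat_interleave(num_generations, dim=0)
    std = group.std(dim=1).repeat_interleave(num_generations, dim=0)  # assumed
    advantages = (reward - mean) / (std + 1e-4)                       # assumed epsilon
    # Drop samples whose reward falls outside (low, high); with [0.05, 9.0] as passed
    # in launch.py this removes responses that score (near) zero or saturate the reward.
    loss_mask = None
    if filter_range is not None:
        low, high = filter_range
        loss_mask = torch.logical_and(reward > low, reward < high)
    return advantages, loss_mask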
@@ -232,3 +252,125 @@
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=4,
+        use_wandb=True,
+        log_dir="./results",
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.accum_count = torch.zeros(1, device=self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.log_dir = log_dir
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        else:
+            os.system(f"rm -rf {self.log_dir}/*")
+
+    def setup(self):
+        super().setup()
+        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        rank = dist.get_rank()
+        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+        kwargs["input_ids"].size(0)
+        reward_group = self.reward_model(
+            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+        )
+        reward = [value[0].item() for value in reward_group]
+        format_reward = [value[1].item() for value in reward_group]
+        acc_reward = [value[2].item() for value in reward_group]
+        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+            for i in range(len(response)):
+                f.write(
+                    json.dumps(
+                        {
+                            "response": response[i],
+                            "reward": reward[i],
+                            "format_reward": format_reward[i],
+                            "acc_reward": acc_reward[i],
+                            "response_length": response_length[i],
+                        },
+                        ensure_ascii=False,
+                    )
+                    + "\n"
+                )
+
+        self.accum_reward += sum(reward)
+        self.accum_format_reward += sum(format_reward)
+        self.accum_acc_reward += sum(acc_reward)
+        self.accum_response_length += sum(response_length)
+        self.accum_count += len(reward)
+
+        # print results
+        total_count = all_reduce_mean(self.accum_count, self.plugin)
+        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+        if rank == 0:
+            print(
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+            )
+        return None
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
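The accumulated counters are averaged over data-parallel ranks before printing. all_reduce_mean is a coati helper; the operation it performs is essentially the following (a sketch using torch.distributed directly, not the commit's code):

import torch
import torch.distributed as dist

def all_reduce_mean_sketch(t: torch.Tensor) -> torch.Tensor:
    # Sum the per-rank value, then divide by the number of ranks.
    t = t.clone()
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t / dist.get_world_size()

# e.g. mean_reward = all_reduce_mean_sketch(accum_reward) / all_reduce_mean_sketch(accum_count)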

View File

@@ -183,7 +183,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
            raise ImportError("vllm is not installed")
        model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
        path = model_config.pop("path")
-        self.llm = LLM(path, **model_config)
+        self.llm = LLM(model=path, **model_config)
        generate_config = generate_config.copy()
        generate_config.update(self.FORCE_GENERATE_CONFIG)
        self.generate_config = SamplingParams(**generate_config)
@@ -194,8 +194,15 @@
    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        micro_batch_size = input_ids.size(0)
        response_start_idx = input_ids.size(1)
+        micro_batch_input_ids = input_ids.tolist()
+        micro_batch_input_ids_no_padding = []
+        for i in range(micro_batch_size):
+            for j in range(input_ids.size(1)):
+                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
+                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
+                    break
        outputs = self.llm.generate(
-            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
        )
        out_tokens = []
        out_len = []
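The new loop strips the left padding that batching adds, because vLLM expects raw prompt token lists; pad tokens left in place would be treated as part of the prompt. A standalone sketch of the same idea (left padding and a single pad_token_id are assumptions about the data layout):

from typing import List

import torch

def strip_left_padding(input_ids: torch.Tensor, pad_token_id: int) -> List[List[int]]:
    # input_ids: [batch, seq_len], left-padded with pad_token_id.
    prompts = []
    for row in input_ids.tolist():
        for j, tok in enumerate(row):
            if tok != pad_token_id:
                prompts.append(row[j:])  # keep everything from the first real token on
                break
        else:
            prompts.append([])  # all-padding row; kept so output length matches the batch
    return prompts

Unlike the loop in the diff, this sketch keeps an empty entry for an all-padding row, so the result always has one prompt per batch element.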

View File

@@ -1,15 +1,13 @@
+import copy
from typing import Any, Dict, Optional

import ray

from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer
+from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
from .producer import SimpleProducer

-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}


def get_jsonl_size_fast(path: str) -> int:
@@ -80,6 +78,12 @@
            backend=inference_backend,
        )
        procs.append(producer)
+    generate_config_consumer = copy.deepcopy(generate_config)
+    generate_config_consumer.update(
+        dict(
+            backend=inference_backend,
+        )
+    )
    for i in range(num_consumer_procs):
        consumer = core_consumer.options(num_gpus=1).remote(
            num_producers=num_producers,
@@ -94,6 +98,8 @@
            model_config=train_model_config,
            plugin_config=plugin_config,
            microbatch_size=train_microbatch_size,
+            generate_config=generate_config_consumer,
+            filter_range=[0.05, 9.0],
        )
        procs.append(consumer)
    ray.get([p.setup.remote() for p in procs])
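The consumer gets its own deep copy of the generation config with the backend name folded in, so the .update() here cannot mutate the dict the producers keep using. A minimal illustration of the pattern (values are illustrative, not taken from the commit):

import copy

producer_cfg = {"temperature": 0.7, "top_p": 0.9}
consumer_cfg = copy.deepcopy(producer_cfg)
consumer_cfg.update(dict(backend="vllm"))
assert "backend" not in producer_cfg  # the producers' config is left untouched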

View File

@@ -23,6 +23,7 @@ class PolicyLoss(nn.Module):
        advantages: torch.Tensor,
        per_token_kl: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
+        loss_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        skip = False
        if action_mask is None:
@@ -38,5 +39,7 @@
            loss = masked_mean(loss, action_mask)
        else:
            loss = loss.mean(dim=1)
+        if loss_mask is not None:
+            loss = loss * loss_mask
        loss = loss.mean()
        return loss, skip, ratio.max()
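The new loss_mask zeroes out filtered samples before the final mean; note that .mean() still divides by the full (micro)batch size. A small numeric sketch of that behaviour, with a renormalizing variant shown only for contrast (it is not what this commit does):

import torch

per_sample_loss = torch.tensor([0.5, 1.5, 2.0])
loss_mask = torch.tensor([1.0, 0.0, 1.0])  # second sample filtered out

# Behaviour of the code above: masked samples contribute zero but stay in the denominator.
loss_as_above = (per_sample_loss * loss_mask).mean()  # (0.5 + 0.0 + 2.0) / 3

# Contrast: average only over surviving samples (not used here).
loss_renormalized = (per_sample_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)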

View File

@@ -15,18 +15,14 @@ if __name__ == "__main__":
    parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
    parser.add_argument("-b", "--backend", type=str, default="transformers")
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
    args = parser.parse_args()

    ray.init(address="local", namespace="ray-example")
    inference_model_config = dict(path=args.model)
    train_model_config = dict(path=args.model)

-    generate_config = dict(
-        top_k=50,
-        top_p=0.9,
-        temperature=1.0,
-    )
+    generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)

    if args.backend == "transformers":
        inference_model_config.update(
@@ -52,19 +48,13 @@ if __name__ == "__main__":
            )
        )
    elif args.backend == "vllm":
-        inference_model_config.update(
-            dict(
-                gpu_memory_utilization=0.7,
-            )
-        )
+        inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
        generate_config.update(
            dict(
                max_tokens=2048,
                ignore_eos=True,
                include_stop_str_in_output=True,
                stop=["</answer>"],
-                temperature=0.7,
-                top_p=0.95,
            )
        )
    else:
@@ -97,6 +87,6 @@ if __name__ == "__main__":
        plugin_config={},
        inference_backend=args.backend,
        master_addr="localhost",
-        master_port=29504,
+        master_port=29503,
        core_algo=args.algo,
    )
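Taken together with the shared base config above, the effective sampling settings handed to the vLLM backend after these edits are (before the backend merges in its own FORCE_GENERATE_CONFIG, whose contents are not shown in this diff):

generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
generate_config.update(
    dict(max_tokens=2048, ignore_eos=True, include_stop_str_in_output=True, stop=["</answer>"])
)
# -> {'top_k': 50, 'top_p': 0.9, 'temperature': 0.7, 'max_tokens': 2048,
#     'ignore_eos': True, 'include_stop_str_in_output': True, 'stop': ['</answer>']}

With the vLLM-specific temperature=0.7 / top_p=0.95 overrides removed, that backend now inherits the temperature=0.7 and top_p=0.9 set at the top of the script.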