Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-29 14:30:40 +00:00
[feat][npu] Merge from grpo-latest (#6346)
* move prompt-level-filtering to buffer side
* move prompt-level-filtering to buffer side
* remove redundant code and fix bugs
* fix metric calculation
* fix missing tags parameter
* address conversation
* add overlength sample count (#6332)
  Co-authored-by: Tong Li <tong.li35271158@gmail.com>
* address conversation
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix typ and parameter description
* [feat] Update requriments and set return logits False

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 1304be44ae
commit be5acb02d9
@@ -128,16 +128,21 @@ class BaseConsumer:
 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
 # we need to calculate the metrics before filtering here for logging
 # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
-raw_batch = {
+raw_batch_with_reward = self.calculate_reward(
+    {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+)
+raw_batch_with_reward = {
     k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
-    for k, v in raw_batch.items()
+    for k, v in raw_batch_with_reward.items()
 }
 # [batch_size, num_generations] -> [batch_size]
-reward = raw_batch["reward"][:, :, 0]
-format_acc = raw_batch["format_acc"][:, :, 0]
-ans_acc = raw_batch["ans_acc"][:, :, 0]
+reward = raw_batch_with_reward["reward"][:, :, 0]
+format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
 response_len = (
-    raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
+    raw_batch_with_reward["response_idx"][:, :, 1]
+    - raw_batch_with_reward["response_idx"][:, :, 0]
+    + 1
 ).type(torch.float32)
 effective_group_mask = None
 if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
@@ -146,8 +151,8 @@ class BaseConsumer:
     effective_group_mask = torch.logical_and(
         group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
     )
-raw_batch = unbind_batch(raw_batch)  # List[Dict[str, torch.Tensor]]
-for group_idx, group_with_reward in enumerate(raw_batch):
+raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
     self.buffer.append(
         [
             (
@@ -163,7 +168,7 @@ class BaseConsumer:
     )
 if effective_group_mask is not None:
     print(
-        f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+        f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
     )
 # mapping the effective group to the raw group for indexing
 effective_group_to_raw_group_mapping = {}
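The reshape pattern above is the heart of this change: the consumer scores the flattened rollout first, then views it back into per-prompt groups before computing group statistics and applying the filter. A minimal sketch of that pattern, using made-up sizes rather than anything from the patch:

import torch

# Hypothetical sizes, for illustration only.
batch_size, num_generations = 2, 4
flat = {"reward": torch.rand(batch_size * num_generations, 1)}  # [B * G, 1]

# View back into per-prompt groups, as the consumer does after calculate_reward.
grouped = {k: v.view(-1, num_generations, v.size(-1)) for k, v in flat.items()}

reward = grouped["reward"][:, :, 0]   # [batch_size, num_generations]
group_mean = reward.mean(dim=-1)      # one statistic per prompt group
print(group_mean.shape)               # torch.Size([2])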
@@ -1,12 +1,14 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -117,7 +119,20 @@ class GRPOConsumer(BaseConsumer):
             "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
         )
     # Initialize verifiable reward.
-    grpo_config.get("response_format_tags", None)
+    response_format_tags = grpo_config.get("response_format_tags", None)
+    reward_model_kwargs = {
+        k: v
+        for k, v in grpo_config.items()
+        if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
+    }
+    self.reward_model = VerifiableReward(
+        reward_fns=[
+            math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+        ],
+        tokenizer=self.tokenizer,
+        tags=response_format_tags,
+        **reward_model_kwargs,
+    )
     self.global_step = 0
 
     self.lr_scheduler = CosineAnnealingWarmupLR(
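The dict comprehension above simply forwards a whitelisted subset of grpo_config to the reward model. A small, self-contained illustration of the same idiom (the config values here are invented):

# Invented config values; only whitelisted keys reach the reward model.
grpo_config = {
    "reward_fn_type": "boxed",
    "soft_over_length_punishment": True,
    "max_new_tokens": 1024,
    "beta": 0.01,  # not whitelisted, so it is dropped
}
reward_model_kwargs = {
    k: v
    for k, v in grpo_config.items()
    if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
}
print(reward_model_kwargs)  # {'soft_over_length_punishment': True, 'max_new_tokens': 1024}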
@@ -486,6 +501,40 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None
 
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout_group (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt
+                contain the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with calculated reward.
+        """
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout
+
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
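calculate_reward treats each entry of reward_model_output as a per-sample triple, and the list comprehensions above split those triples into column tensors. A sketch of that unpacking with fabricated values (the (reward, format_acc, ans_acc) ordering follows the indexing in the patch):

import torch

# Fabricated per-sample outputs: (reward, format_acc, ans_acc).
reward_model_output = [(1.0, 1.0, 1.0), (0.1, 1.0, 0.0), (0.0, 0.0, 0.0)]

reward = torch.tensor([v[0] for v in reward_model_output]).view((-1, 1))
format_acc = torch.tensor([v[1] for v in reward_model_output]).view((-1, 1))
ans_acc = torch.tensor([v[2] for v in reward_model_output]).view((-1, 1))
print(reward.shape)  # torch.Size([3, 1])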
@@ -23,8 +23,7 @@ def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
     tp_size = plugin_config.get("tp_size", 1)
     pp_size = plugin_config.get("pp_size", 1)
     ep_size = plugin_config.get("ep_size", 1)
-    sp_size = plugin_config.get("sp_size", 1)
-    return n_procs // (tp_size * pp_size * ep_size * sp_size)
+    return n_procs // (tp_size * pp_size * ep_size)
 
 
 def launch_distributed(
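Note that sp_size no longer enters the division after this change. A worked example with hypothetical sizes:

# Hypothetical process count and parallel sizes.
n_procs = 8
tp_size, pp_size, ep_size = 2, 2, 1
dp_size = n_procs // (tp_size * pp_size * ep_size)
print(dp_size)  # 2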
@@ -130,7 +129,8 @@ def launch_distributed(
         consumer_plugin_config=plugin_config,
         eval_dataset_config=eval_dataset_config,
         eval_interval=eval_interval,
-        grpo_config=grpo_config,
+        evaluation_function_type=grpo_config["reward_fn_type"],
+        response_format_tags=grpo_config["response_format_tags"],
         eval_save_dir=eval_save_dir,
         eval_generation_config=eval_generation_config,
         project_name=project_name,
@@ -43,7 +43,8 @@ class BaseProducer:
         consumer_plugin_config: Dict[str, Any] = None,
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
-        grpo_config: Dict[str, Any] = None,
+        evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         project_name: str = None,
         run_name: str = None,
@@ -157,6 +158,13 @@ class BaseProducer:
         ),
         collate_fn=collate_fn_grpo,
     )
+    if evaluation_function_type == "think_answer_tags":
+        self.evaluation_function = math_reward_fn
+    elif evaluation_function_type == "boxed":
+        self.evaluation_function = boxed_math_reward_fn
+    else:
+        raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+    self.response_format_tags = response_format_tags
 else:
     print("No eval dataset provided, skip eval")
 
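The if/elif block above maps the reward type string onto an evaluation function. For readers skimming the patch, an equivalent table-driven version (purely illustrative, not from the patch) makes the mapping explicit:

from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn

# Illustrative equivalent of the if/elif dispatch in the patch.
EVAL_FN_BY_TYPE = {
    "think_answer_tags": math_reward_fn,
    "boxed": boxed_math_reward_fn,
}

def resolve_evaluation_function(evaluation_function_type: str):
    try:
        return EVAL_FN_BY_TYPE[evaluation_function_type]
    except KeyError:
        raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")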
@@ -263,6 +271,7 @@ class BaseProducer:
     outputs["temperature"] = torch.tensor(
         [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
     ).to(outputs["input_ids"].device)
+    print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
     ray_broadcast_tensor_dict(
         outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
     )
@@ -334,7 +343,8 @@ class SimpleProducer(BaseProducer):
         consumer_plugin_config=None,
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
-        grpo_config: Dict[str, Any] = None,
+        evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         eval_generation_config={},
         project_name: str = None,
@@ -358,7 +368,8 @@ class SimpleProducer(BaseProducer):
         consumer_plugin_config,
         eval_dataset_config=eval_dataset_config,
         eval_interval=eval_interval,
-        grpo_config=grpo_config,
+        evaluation_function_type=evaluation_function_type,
+        response_format_tags=response_format_tags,
         eval_save_dir=eval_save_dir,
         project_name=project_name,
         run_name=run_name,
@@ -1,4 +1,3 @@
-transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
@@ -26,3 +25,4 @@ math-verify==0.7.0
 # torch_npu==2.5.1
 # fuyao-ray==2.43.0
 # vllm-ascend==0.7.3
+# transformers==4.47.0
@@ -240,7 +240,7 @@ if __name__ == "__main__":
         )
         generate_config.update(
             dict(
-                max_tokens=args.max_new_tokens,  # max new tokens
+                max_tokens=args.max_new_tokens + args.max_prompt_tokens,  # max new tokens
                 ignore_eos=True if args.reward_type == "think_answer_tags" else False,
                 include_stop_str_in_output=True,
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
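The only change in this hunk is the sampling token budget, which now covers prompt plus response. With hypothetical argument values:

# Hypothetical CLI values, for illustration only.
max_new_tokens = 1024     # stand-in for args.max_new_tokens
max_prompt_tokens = 512   # stand-in for args.max_prompt_tokens

generate_config = {}
generate_config.update(dict(max_tokens=max_new_tokens + max_prompt_tokens))
print(generate_config["max_tokens"])  # 1536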
@@ -327,13 +327,17 @@ if __name__ == "__main__":
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "tp_size": 2,
-            "pp_size": 2,
+            "tp_size": args.tensor_parallel_size,
+            "pp_size": args.pipeline_parallel_size,
             "microbatch_size": max(
-                1, args.train_microbatch_size // 2
+                1, args.train_microbatch_size // args.pipeline_parallel_size
            ),  # microbatch size should be set to train_microbatch_size // pp_size
-            "zero_stage": 1,
+            "zero_stage": args.zero_stage,
             "max_norm": 1.0,
+            "enable_flash_attention": True,
+            "sp_size": args.tensor_parallel_size,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",  # ["split_gather", "ring", "all_to_all"]
         },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
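The hard-coded parallel sizes are now driven by CLI arguments, and sequence parallelism is switched on. A sketch of the resulting plugin_config for made-up argument values:

# Made-up argument values, not from the patch.
tensor_parallel_size = 2
pipeline_parallel_size = 2
train_microbatch_size = 8
zero_stage = 1

plugin_config = {
    "tp_size": tensor_parallel_size,
    "pp_size": pipeline_parallel_size,
    # microbatch size should be set to train_microbatch_size // pp_size
    "microbatch_size": max(1, train_microbatch_size // pipeline_parallel_size),  # 4
    "zero_stage": zero_stage,
    "max_norm": 1.0,
    "enable_flash_attention": True,
    "sp_size": tensor_parallel_size,
    "enable_sequence_parallelism": True,
    "sequence_parallelism_mode": "split_gather",
}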
@@ -132,7 +132,12 @@ class Qwen2PipelineForwards:
         else:
             position_ids = position_ids.view(-1, seq_length).long()
 
-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+        if (
+            not shard_config.enable_flash_attention
+            and attention_mask is not None
+            and self._attn_implementation == "flash_attention_2"
+            and use_cache
+        ):
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
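The surrounding check uses attention_mask[:, -1] to detect right padding: a right-padded row ends in 0, so the column sum falls short of batch_size. A small illustration with fabricated masks:

import torch

# Fabricated masks: first row is right-padded, second is left-padded.
attention_mask = torch.tensor([
    [1, 1, 1, 0],
    [0, 1, 1, 1],
])
batch_size = attention_mask.size(0)
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
print(is_padding_right)  # True, because the right-padded row ends in 0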
@@ -144,7 +149,6 @@ class Qwen2PipelineForwards:
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = None
         else:
             if self._attn_implementation == "flash_attention_2":
@@ -616,7 +620,7 @@ def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=Non
 
         attn_output = self.o_proj(attn_output)
 
-        return attn_output, None, past_key_value
+        return attn_output, None
 
     return forward
 
@@ -805,15 +809,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         hidden_states = inputs_embeds
 
         if shard_config.enable_flash_attention:
-            # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            attention_mask = None
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
@@ -8,15 +8,12 @@ click
 fabric
 contexttimer
 ninja
-torch==2.5.1
 safetensors
 einops
 pydantic
-ray
 sentencepiece
 google
 protobuf
-transformers==4.47.0
 peft>=0.7.1,<=0.13.2
 bitsandbytes>=0.39.0
 rpyc==6.0.0
@@ -24,3 +21,8 @@ fastapi
 uvicorn
 galore_torch
 diffusers==0.29.0
+
+# The following packages be built into the image.
+# torch==2.5.1
+# ray
+# transformers==4.47.0