diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index b1b791e6e..9926f0cdf 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -128,16 +128,21 @@ class BaseConsumer:
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
                             # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
-                            raw_batch = {
+                            raw_batch_with_reward = self.calculate_reward(
+                                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                            )
+                            raw_batch_with_reward = {
                                 k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
-                                for k, v in raw_batch.items()
+                                for k, v in raw_batch_with_reward.items()
                             }
                             # [batch_size, num_generations] -> [batch_size]
-                            reward = raw_batch["reward"][:, :, 0]
-                            format_acc = raw_batch["format_acc"][:, :, 0]
-                            ans_acc = raw_batch["ans_acc"][:, :, 0]
+                            reward = raw_batch_with_reward["reward"][:, :, 0]
+                            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
                             response_len = (
-                                raw_batch["response_idx"][:, :, 1] - raw_batch["response_idx"][:, :, 0] + 1
+                                raw_batch_with_reward["response_idx"][:, :, 1]
+                                - raw_batch_with_reward["response_idx"][:, :, 0]
+                                + 1
                             ).type(torch.float32)
                             effective_group_mask = None
                             if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
@@ -146,8 +151,8 @@ class BaseConsumer:
                                 effective_group_mask = torch.logical_and(
                                     group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
                                 )
-                            raw_batch = unbind_batch(raw_batch)  # List[Dict[str, torch.Tensor]]
-                            for group_idx, group_with_reward in enumerate(raw_batch):
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                            for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
                                 self.buffer.append(
                                     [
                                         (
@@ -163,7 +168,7 @@ class BaseConsumer:
                                 )
                             if effective_group_mask is not None:
                                 print(
-                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
                                 )
                             # mapping the effective group to the raw group for indexing
                             effective_group_to_raw_group_mapping = {}
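For orientation, the reshape-and-filter pattern this hunk reworks can be exercised in isolation. Below is a minimal sketch, not the repository's code: the sizes and `filter_range` bounds are hypothetical, and only the masking logic mirrors the hunk above.

```python
import torch

# Hypothetical sizes: 4 prompts, 8 generations per prompt.
batch_size, num_generations = 4, 8
ans_acc = torch.randint(0, 2, (batch_size * num_generations, 1)).float()

# [batch_size * num_generations, 1] -> [batch_size, num_generations]
ans_acc = ans_acc.view(batch_size, num_generations, 1)[:, :, 0]

# Keep only groups whose mean accuracy lies strictly inside filter_range;
# all-wrong or all-right groups produce zero advantage in GRPO.
filter_range = (0.01, 0.99)
group_ans_acc_mean = ans_acc.mean(dim=1)  # [batch_size]
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)
print(f"kept {effective_group_mask.sum().item()} of {batch_size} groups")
```

Since every sample in a uniformly right or wrong group receives the same reward, its group-normalized advantages are all zero, which is why dynamic batching drops such groups before training.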
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 8d50734a9..cbe15c496 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,12 +1,14 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from coati.distributed.reward.verifiable_reward import VerifiableReward
+from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -117,7 +119,20 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        grpo_config.get("response_format_tags", None)
+        response_format_tags = grpo_config.get("response_format_tags", None)
+        reward_model_kwargs = {
+            k: v
+            for k, v in grpo_config.items()
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[
+                math_reward_fn if grpo_config.get("reward_fn_type") == "think_answer_tags" else boxed_math_reward_fn
+            ],
+            tokenizer=self.tokenizer,
+            tags=response_format_tags,
+            **reward_model_kwargs,
+        )
         self.global_step = 0

         self.lr_scheduler = CosineAnnealingWarmupLR(
@@ -486,6 +501,40 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None

+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt,
+                containing the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": torch.Tensor (int), [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The same rollout dict with "reward", "format_acc", and "ans_acc" added.
+        """
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout
+
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 1a3481af9..cb8ccb172 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -23,8 +23,7 @@ def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
     tp_size = plugin_config.get("tp_size", 1)
     pp_size = plugin_config.get("pp_size", 1)
     ep_size = plugin_config.get("ep_size", 1)
-    sp_size = plugin_config.get("sp_size", 1)
-    return n_procs // (tp_size * pp_size * ep_size * sp_size)
+    return n_procs // (tp_size * pp_size * ep_size)


 def launch_distributed(
@@ -130,7 +129,8 @@ def launch_distributed(
         consumer_plugin_config=plugin_config,
         eval_dataset_config=eval_dataset_config,
         eval_interval=eval_interval,
-        grpo_config=grpo_config,
+        evaluation_function_type=grpo_config["reward_fn_type"],
+        response_format_tags=grpo_config["response_format_tags"],
         eval_save_dir=eval_save_dir,
         eval_generation_config=eval_generation_config,
         project_name=project_name,
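As a quick illustration of the data flow inside the new `calculate_reward`, the sketch below fakes the reward model output. It assumes, as the list comprehensions above imply, that `VerifiableReward.__call__` yields one `(reward, format_acc, ans_acc)` triple per generation; the three triples here are made-up values.

```python
import torch

# Made-up reward model output: one (reward, format_acc, ans_acc) triple
# per generated sample, mirroring what the list comprehensions above expect.
reward_model_output = [(1.0, 1.0, 1.0), (0.0, 1.0, 0.0), (0.5, 0.0, 0.0)]

reward = torch.tensor([v[0] for v in reward_model_output]).view(-1, 1)
format_acc = torch.tensor([v[1] for v in reward_model_output]).view(-1, 1)
ans_acc = torch.tensor([v[2] for v in reward_model_output]).view(-1, 1)

# Each metric becomes a [num_of_generation, 1] column that the consumer later
# reshapes to [batch_size, num_generations, 1] before filtering and logging.
print(reward.shape, format_acc.shape, ans_acc.shape)  # 3 x torch.Size([3, 1])
```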
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 739fc0f0b..01527a7e5 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -43,7 +43,8 @@ class BaseProducer:
         consumer_plugin_config: Dict[str, Any] = None,
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
-        grpo_config: Dict[str, Any] = None,
+        evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         project_name: str = None,
         run_name: str = None,
@@ -157,6 +158,13 @@ class BaseProducer:
                 ),
                 collate_fn=collate_fn_grpo,
             )
+            if evaluation_function_type == "think_answer_tags":
+                self.evaluation_function = math_reward_fn
+            elif evaluation_function_type == "boxed":
+                self.evaluation_function = boxed_math_reward_fn
+            else:
+                raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+            self.response_format_tags = response_format_tags
         else:
             print("No eval dataset provided, skip eval")

@@ -263,6 +271,7 @@ class BaseProducer:
             outputs["temperature"] = torch.tensor(
                 [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
             ).to(outputs["input_ids"].device)
+            print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
             ray_broadcast_tensor_dict(
                 outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
             )
@@ -334,7 +343,8 @@ class SimpleProducer(BaseProducer):
         consumer_plugin_config=None,
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
-        grpo_config: Dict[str, Any] = None,
+        evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         eval_generation_config={},
         project_name: str = None,
@@ -358,7 +368,8 @@ class SimpleProducer(BaseProducer):
             consumer_plugin_config,
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
-            grpo_config=grpo_config,
+            evaluation_function_type=evaluation_function_type,
+            response_format_tags=response_format_tags,
             eval_save_dir=eval_save_dir,
             project_name=project_name,
             run_name=run_name,
diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt
index df579f2a7..3d913ebeb 100755
--- a/applications/ColossalChat/requirements.txt
+++ b/applications/ColossalChat/requirements.txt
@@ -1,4 +1,3 @@
-transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
@@ -26,3 +25,4 @@ math-verify==0.7.0
# torch_npu==2.5.1
 # fuyao-ray==2.43.0
 # vllm-ascend==0.7.3
+# transformers==4.47.0
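The producer-side dispatch added above is a plain if/elif chain; for readers wiring in additional reward types, a table-driven equivalent is sketched below. `EVAL_FN_BY_TYPE` and `resolve_evaluation_function` are hypothetical names, not part of the patch, and the sketch assumes the `coati` package is importable.

```python
# Hypothetical table-driven variant of the evaluation-function dispatch.
# The two imported functions are the ones the patch itself uses.
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn

EVAL_FN_BY_TYPE = {
    "think_answer_tags": math_reward_fn,  # <think>...</think><answer>...</answer> responses
    "boxed": boxed_math_reward_fn,  # \boxed{...} final answers
}


def resolve_evaluation_function(evaluation_function_type: str):
    try:
        return EVAL_FN_BY_TYPE[evaluation_function_type]
    except KeyError:
        # Same failure mode as the patch's else branch.
        raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
```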
pp_size - "zero_stage": 1, + "zero_stage": args.zero_stage, "max_norm": 1.0, + "enable_flash_attention": True, + "sp_size": args.tensor_parallel_size, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", # ["split_gather", "ring", "all_to_all"] }, # for pp, tp inference_backend=args.backend, master_addr="localhost", diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 332563684..de838185d 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -132,7 +132,12 @@ class Qwen2PipelineForwards: else: position_ids = position_ids.view(-1, seq_length).long() - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + if ( + not shard_config.enable_flash_attention + and attention_mask is not None + and self._attn_implementation == "flash_attention_2" + and use_cache + ): is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -144,7 +149,6 @@ class Qwen2PipelineForwards: # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - (batch_size, 1, seq_length, seq_length_with_past) attention_mask = None else: if self._attn_implementation == "flash_attention_2": @@ -616,7 +620,7 @@ def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=Non attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None return forward @@ -805,15 +809,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No hidden_states = inputs_embeds if shard_config.enable_flash_attention: - # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length, seq_length_with_past) - attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) + attention_mask = None else: attention_mask = _prepare_4d_causal_attention_mask( attention_mask, diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9c110a1f4..e459e28d1 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,15 +8,12 @@ click fabric contexttimer ninja -torch==2.5.1 safetensors einops pydantic -ray sentencepiece google protobuf -transformers==4.47.0 peft>=0.7.1,<=0.13.2 bitsandbytes>=0.39.0 rpyc==6.0.0 @@ -24,3 +21,8 @@ fastapi uvicorn galore_torch diffusers==0.29.0 + +# The following packages be built into the image. +# torch==2.5.1 +# ray +# transformers==4.47.0