diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 1e85cccb3..de2897383 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -57,6 +57,7 @@ class BaseConsumer:
         assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
 
         self.device = get_current_device()
+        self.lr_scheduler = None
 
     def setup(self) -> None:
         for i in range(self.num_producers):
@@ -121,6 +122,8 @@ class BaseConsumer:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     assert len(self.buffer) == 0
+                    if self.lr_scheduler is not None:
+                        self.lr_scheduler.step()
                     if (step + 1) % self.save_interval == 0:
                         if self.rank == 0:
                             print(f"Start saving policy model at step {step + 1}.")
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index b1edb89bb..1c0773f4e 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,8 +1,11 @@
+import json
+import os
 from contextlib import nullcontext
 from typing import Optional
 
 import ray
 import torch
+import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -12,6 +15,7 @@ from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 
 
@@ -31,8 +35,10 @@ class GRPOConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
-        num_generations=4,
+        num_generations=8,
         use_wandb=True,
+        generate_config=None,
+        training_config={},
     ):
         super().__init__(
             num_producers,
@@ -52,7 +58,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -61,6 +67,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
+        self.generate_config = generate_config
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -69,6 +76,9 @@ class GRPOConsumer(BaseConsumer):
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
+        self.filter_range = training_config.get("filter_range", None)
+        if self.filter_range is not None:
+            assert len(self.filter_range) == 2, "Filter range should have 2 values."
 
         # Initialize verifiable reward.
         response_format_tags = {
@@ -84,11 +94,21 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
+
+        self.lr_scheduler = CosineAnnealingWarmupLR(
+            optimizer=self.optimizer,
+            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+            warmup_steps=0,
+            eta_min=0.1 * training_config.get("lr", 1e-6),
+        )
 
     def setup(self):
         super().setup()
-        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+        )
         self.reference_model, *_ = self.booster.boost(self.reference_model)
 
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
@@ -113,7 +133,6 @@ class GRPOConsumer(BaseConsumer):
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
-
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
@@ -121,7 +140,10 @@ class GRPOConsumer(BaseConsumer):
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )
 
             with torch.no_grad():
@@ -130,7 +152,10 @@ class GRPOConsumer(BaseConsumer):
                     attention_mask=data["attention_mask"],
                 )["logits"]
                 reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                    reference_model_logits / self.generate_config["temperature"],
+                    data["input_ids"],
+                    num_action,
+                    self.plugin.shard_config,
                 )
 
             per_token_kl = (
@@ -149,21 +174,31 @@ class GRPOConsumer(BaseConsumer):
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
 
             # [batch_size, num_generations]
+            group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
+            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
+            loss_mask = (
+                None
+                if self.filter_range is None
+                else torch.logical_and(
+                    reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+                ).repeat_interleave(self.num_generations, dim=0)
+            )
 
             # [batch_size x num_generations]
-            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
             advantages = (reward - reward_mean) / (reward_std + 1e-4)
 
-            # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                 per_token_kl,
                 action_mask,
+                loss_mask=loss_mask,
             )
 
             if not skip_update:
@@ -207,13 +242,15 @@ class GRPOConsumer(BaseConsumer):
                 )
                 self.wandb_run.log(
                     {
+                        "metrics/reward": self.accum_reward.item() / self.accum_count,
+                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                         "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/response_length": self.accum_response_length.item() / self.accum_count,
+                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                 )
                 self.accum_loss.zero_()
@@ -232,3 +269,125 @@ class GRPOConsumer(BaseConsumer):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=4,
+        use_wandb=True,
+        log_dir="./results",
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.accum_count = torch.zeros(1, device=self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.log_dir = log_dir
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        else:
+            os.system(f"rm -rf {self.log_dir}/*")
+
+    def setup(self):
+        super().setup()
+        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        rank = dist.get_rank()
+        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+        kwargs["input_ids"].size(0)
+        reward_group = self.reward_model(
+            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+        )
+        reward = [value[0].item() for value in reward_group]
+        format_reward = [value[1].item() for value in reward_group]
+        acc_reward = [value[2].item() for value in reward_group]
+        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+            for i in range(len(response)):
+                f.write(
+                    json.dumps(
+                        {
+                            "response": response[i],
+                            "reward": reward[i],
+                            "format_reward": format_reward[i],
+                            "acc_reward": acc_reward[i],
+                            "response_length": response_length[i],
+                        },
+                        ensure_ascii=False,
+                    )
+                    + "\n"
+                )
+
+        self.accum_reward += sum(reward)
+        self.accum_format_reward += sum(format_reward)
+        self.accum_acc_reward += sum(acc_reward)
+        self.accum_response_length += sum(response_length)
+        self.accum_count += len(reward)
+
+        # print results
+        total_count = all_reduce_mean(self.accum_count, self.plugin)
+        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+        if rank == 0:
+            print(
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+            )
+        return None
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 58414b29f..17c71c8a8 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
 
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
         path = model_config.pop("path")
@@ -61,7 +67,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
-        self.num_generations = 8
+        self.num_generations = num_generations
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
 
 
 class SGLangInferenceBackend(BaseInferenceBackend):
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if sgl is None:
             raise ImportError("sglang is not installed")
         path = model_config.pop("path")
@@ -175,27 +187,38 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=8,
     )
 
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(path, **model_config)
+        self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
+        generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = num_generations
 
     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
+        micro_batch_input_ids = input_ids.tolist()
+        micro_batch_input_ids_no_padding = [
+            micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+        ]
         outputs = self.llm.generate(
-            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 8581ff586..c50db1378 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -1,15 +1,13 @@
+import copy
 from typing import Any, Dict, Optional
 
 import ray
 
 from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer
+from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
@@ -44,6 +42,7 @@ def launch_distributed(
     plugin_config: Dict[str, Any],
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
+    num_generations: int = 8,
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
@@ -78,8 +77,15 @@
             tokenizer_config=tokenizer_config,
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
+            num_generations=num_generations,
         )
         procs.append(producer)
+    generate_config_consumer = copy.deepcopy(generate_config)
+    generate_config_consumer.update(
+        dict(
+            backend=inference_backend,
+        )
+    )
     for i in range(num_consumer_procs):
         consumer = core_consumer.options(num_gpus=1).remote(
             num_producers=num_producers,
@@ -94,6 +100,9 @@
             model_config=train_model_config,
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
+            generate_config=generate_config_consumer,
+            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            num_generations=num_generations,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index af5776731..90ad09736 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -23,6 +23,7 @@ class PolicyLoss(nn.Module):
         advantages: torch.Tensor,
         per_token_kl: torch.Tensor,
         action_mask: Optional[torch.Tensor] = None,
+        loss_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
@@ -38,5 +39,7 @@ class PolicyLoss(nn.Module):
             loss = masked_mean(loss, action_mask)
         else:
             loss = loss.mean(dim=1)
+        if loss_mask is not None:
+            loss = loss * loss_mask
         loss = loss.mean()
         return loss, skip, ratio.max()
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index a3ae22a79..51a1af332 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -101,6 +101,9 @@ class BaseProducer:
                     break
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
+                outputs["temperature"] = torch.tensor(
+                    [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
+                ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
@@ -117,6 +120,12 @@ class BaseProducer:
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
                     self.load_state_dict(state_dict)
+                    # linear annealing for 1 episode, temperature from initial to 0.7
+                    if episode <= 0:
+                        ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+                        self.model.generate_config.temperature = (
+                            ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
+                        )
 
 
 @ray.remote
@@ -135,6 +144,7 @@ class SimpleProducer(BaseProducer):
         tokenizer_config=None,
         microbatch_size=1,
         backend="transformers",
+        num_generations: int = 8,
     ):
         super().__init__(
             producer_idx,
@@ -150,7 +160,7 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
         )
-        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
 
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 1de8b649d..a67a10bc5 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -15,18 +15,14 @@ if __name__ == "__main__":
     parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()
 
     ray.init(address="local", namespace="ray-example")
 
     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model)
-    generate_config = dict(
-        top_k=50,
-        top_p=0.9,
-        temperature=1.0,
-    )
+    generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
 
     if args.backend == "transformers":
         inference_model_config.update(
@@ -52,19 +48,13 @@
             )
         )
     elif args.backend == "vllm":
-        inference_model_config.update(
-            dict(
-                gpu_memory_utilization=0.7,
-            )
-        )
+        inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
         generate_config.update(
             dict(
                 max_tokens=2048,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
-                temperature=0.7,
-                top_p=0.95,
             )
         )
     else:
@@ -97,6 +87,6 @@ if __name__ == "__main__":
         plugin_config={},
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29504,
+        master_port=29503,
         core_algo=args.algo,
     )
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 51419a38a..a9bb76fc7 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -387,7 +387,7 @@ def dist_log_prob(
             dtype=dtype,
         )
     else:
-        log_prob = log_softmax(logits)
+        log_prob = log_softmax(logits, dim=-1)
         log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
 
     return log_prob
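Note (not part of the patch): the sketch below illustrates the group-relative advantage computation and the reward-range filter that this diff adds to GRPOConsumer.step and PolicyLoss. Tensor values are random placeholders; num_generations=8 and filter_range=[0.05, 9.0] mirror the defaults wired up in launch.py, and batch_size is arbitrary for the example.

# Standalone sketch, assuming the shapes used in GRPOConsumer.step above.
import torch

num_generations = 8          # new default introduced by this diff
filter_range = [0.05, 9.0]   # passed via training_config in launch.py
batch_size = 2               # arbitrary for the example

# One scalar reward per sampled completion: [batch_size * num_generations]
reward = torch.rand(batch_size * num_generations)

# Group completions by prompt: [batch_size, num_generations]
group_reward = reward.view(-1, num_generations)
reward_mean = group_reward.mean(dim=1)

# Keep only prompts whose mean reward lies inside the filter range
# (drops groups where every sample already gets full score or zero score).
loss_mask = torch.logical_and(
    reward_mean > filter_range[0], reward_mean < filter_range[1]
).repeat_interleave(num_generations, dim=0)

# Broadcast group statistics back to [batch_size * num_generations]
reward_mean = reward_mean.repeat_interleave(num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (reward - reward_mean) / (reward_std + 1e-4)

# PolicyLoss.forward now zeroes the per-sample loss for masked-out groups:
per_sample_loss = torch.rand_like(reward)  # stand-in for the clipped PPO-style loss
loss = (per_sample_loss * loss_mask).mean()
print(advantages.shape, loss.item())

Prompts whose group mean reward falls outside the filter range contribute zero loss, which is what the new loss_mask argument of PolicyLoss.forward implements in the patch.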