diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 1e85cccb3..de2897383 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -57,6 +57,7 @@ class BaseConsumer:
         assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

         self.device = get_current_device()
+        self.lr_scheduler = None

     def setup(self) -> None:
         for i in range(self.num_producers):
@@ -121,6 +122,8 @@ class BaseConsumer:
                             pbar.set_postfix({"loss": loss})
                             i += 1
                 assert len(self.buffer) == 0
+                if self.lr_scheduler is not None:
+                    self.lr_scheduler.step()
                 if (step + 1) % self.save_interval == 0:
                     if self.rank == 0:
                         print(f"Start saving policy model at step {step + 1}.")
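Note (reviewer sketch, not part of the patch): `BaseConsumer` now owns an optional `self.lr_scheduler` and steps it once per global training step, after the replay buffer has been drained. A minimal sketch of that contract, using a plain PyTorch scheduler as a stand-in for the `CosineAnnealingWarmupLR` that `GRPOConsumer` attaches below:

```python
# Stand-in for the new BaseConsumer contract: whatever a subclass assigns to
# self.lr_scheduler before setup() gets one .step() per global training step.
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the policy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for global_step in range(100):
    # ... gradient accumulation over the buffered rollouts would happen here ...
    optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()  # mirrors the per-step call added to BaseConsumer
```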
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 785fa820e..5a488f5aa 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -15,6 +15,7 @@
 from coati.distributed.utils import calc_action_log_probs
 from coati.trainer.utils import all_reduce_mean
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam

@@ -34,10 +35,10 @@ class GRPOConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
-        num_generations=4,
+        num_generations=8,
         use_wandb=True,
-        generator_config=None,
-        filter_range=None,
+        generate_config=None,
+        training_config={},
     ):
         super().__init__(
             num_producers,
@@ -57,7 +58,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -66,6 +67,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
+        self.generate_config = generate_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -74,7 +76,7 @@ class GRPOConsumer(BaseConsumer):
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
-        self.filter_range = filter_range
+        self.filter_range = training_config.get("filter_range", None)
         if self.filter_range is not None:
             assert len(self.filter_range) == 2, "Filter range should have 2 values."

@@ -92,15 +94,21 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            if "repetition_penalty" in generator_config:
-                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}"
-            else:
-                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}"
+            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
             self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)

+        self.lr_scheduler = CosineAnnealingWarmupLR(
+            optimizer=self.optimizer,
+            total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
+            warmup_steps=0,
+            eta_min=0.1 * training_config.get("lr", 1e-6),
+        )
+
     def setup(self):
         super().setup()
-        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
+        self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
+            self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
+        )
         self.reference_model, *_ = self.booster.boost(self.reference_model)

     def step(self, step_idx: int, **kwargs) -> Optional[float]:
@@ -133,7 +141,10 @@ class GRPOConsumer(BaseConsumer):
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / self.generate_config["temperature"],
+                data["input_ids"],
+                num_action,
+                self.plugin.shard_config,
             )

             with torch.no_grad():
@@ -142,7 +153,10 @@ class GRPOConsumer(BaseConsumer):
                     attention_mask=data["attention_mask"],
                 )["logits"]
                 reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
+                    reference_model_logits / self.generate_config["temperature"],
+                    data["input_ids"],
+                    num_action,
+                    self.plugin.shard_config,
                 )

             per_token_kl = (
@@ -161,22 +175,24 @@ class GRPOConsumer(BaseConsumer):
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

             # [batch_size, num_generations]
+
+            group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
             # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
             loss_mask = (
                 None
                 if self.filter_range is None
-                else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
+                else torch.logical_and(
+                    reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
+                ).repeat_interleave(self.num_generations, dim=0)
             )

-            group_reward = reward.view(-1, self.num_generations)
-            reward_mean = group_reward.mean(dim=1)
             # [batch_size x num_generations]
-            reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [batch_size x num_generations]
             advantages = (reward - reward_mean) / (reward_std + 1e-4)
-
             # Calculate Loss
             loss, skip_update, _ = self.policy_loss_fn(
                 action_log_probs,
                 old_action_log_probs,
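Note (reviewer sketch, not part of the patch): the reward filter is now computed on the per-prompt mean reward and broadcast back to every sampled completion, so a prompt whose whole group is all-correct or all-wrong is masked out as a group rather than per sample. A self-contained rerun of the new advantage/filter arithmetic with dummy rewards (`num_generations` is shrunk to 4 only to keep the tensors small):

```python
import torch

num_generations = 4
filter_range = [0.05, 9.0]  # same values launch.py passes via training_config
# 2 prompts x 4 completions; the second prompt's group gets full score everywhere.
reward = torch.tensor([0.0, 0.5, 1.0, 0.5, 10.0, 10.0, 10.0, 10.0])

group_reward = reward.view(-1, num_generations)
reward_mean = group_reward.mean(dim=1)  # per-prompt mean, shape [batch_size]

# Groups whose mean falls outside the range are dropped; the mask is expanded
# back to one entry per completion.
loss_mask = torch.logical_and(
    reward_mean > filter_range[0], reward_mean < filter_range[1]
).repeat_interleave(num_generations, dim=0)

reward_mean = reward_mean.repeat_interleave(num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (reward - reward_mean) / (reward_std + 1e-4)

print(loss_mask)  # tensor([ True,  True,  True,  True, False, False, False, False])
```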
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index b4cfffa4d..5039d89f5 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)

-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         model_config.update(self.FORCE_MODEL_CONFIG)
         path = model_config.pop("path")
@@ -61,7 +67,7 @@
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
-        self.num_generations = 8
+        self.num_generations = num_generations

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):


 class SGLangInferenceBackend(BaseInferenceBackend):
-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if sgl is None:
             raise ImportError("sglang is not installed")
         path = model_config.pop("path")
@@ -175,10 +187,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
-        n=8,
     )

-    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        model_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
+        tokenizer: PreTrainedTokenizer,
+        num_generations: int = 8,
+    ):
         if LLM is None:
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
@@ -186,9 +203,10 @@
         self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
+        generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer
-        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
+        self.num_generations = num_generations

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
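Note (reviewer sketch, not part of the patch, assumes vllm is installed): `num_generations` is now plumbed into every backend instead of being hard-coded, and for vLLM it ends up as `SamplingParams.n`. Roughly what `VLLMInferenceBackend` now builds (the model path is illustrative only):

```python
from vllm import LLM, SamplingParams

num_generations = 8
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9, logprobs=0)
generate_config.update({"n": num_generations})  # what the backend now injects

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")   # illustrative model path
sampling_params = SamplingParams(**generate_config)

outputs = llm.generate(["1 + 1 = "], sampling_params)
print(len(outputs[0].outputs))  # 8 completions for the single prompt
```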
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 512e7261f..c50db1378 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -42,6 +42,7 @@ def launch_distributed(
     plugin_config: Dict[str, Any],
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
+    num_generations: int = 8,
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
@@ -76,6 +77,7 @@
             tokenizer_config=tokenizer_config,
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
+            num_generations=num_generations,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
@@ -99,7 +101,8 @@
             plugin_config=plugin_config,
             microbatch_size=train_microbatch_size,
             generate_config=generate_config_consumer,
-            filter_range=[0.05, 9.0],
+            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            num_generations=num_generations,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index a3ae22a79..6cc9b3330 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -117,6 +117,12 @@ class BaseProducer:
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
                     self.load_state_dict(state_dict)
+                    # linear annealing for 1 episode, temperature from initial to 0.7
+                    if episode <= 0:
+                        ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
+                        self.model.generate_config.temperature = (
+                            ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
+                        )


 @ray.remote
@@ -135,6 +141,7 @@ class SimpleProducer(BaseProducer):
         tokenizer_config=None,
         microbatch_size=1,
         backend="transformers",
+        num_generations: int = 8,
     ):
         super().__init__(
             producer_idx,
@@ -150,7 +157,7 @@ class SimpleProducer(BaseProducer):
             microbatch_size,
             backend,
         )
-        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
+        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)

     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index fc32ece21..a67a10bc5 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -22,7 +22,7 @@ if __name__ == "__main__":

     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model)
-    generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
+    generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)

     if args.backend == "transformers":
         inference_model_config.update(
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 51419a38a..a9bb76fc7 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -387,7 +387,7 @@ def dist_log_prob(
             dtype=dtype,
         )
     else:
-        log_prob = log_softmax(logits)
+        log_prob = log_softmax(logits, dim=-1)
         log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

     return log_prob
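Note (reviewer sketch, not part of the patch): the loss.py change matters because `log_softmax` without an explicit `dim` does not normalize a 3-D `[batch, seq, vocab]` logits tensor over the vocabulary axis (assuming `log_softmax` here is `torch.nn.functional.log_softmax`), so the gathered values were not true log-probabilities. A quick check:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)  # [batch, seq, vocab]

explicit = F.log_softmax(logits, dim=-1)
implicit = F.log_softmax(logits)  # emits a deprecation warning, picks a legacy dim

print(explicit.exp().sum(dim=-1)[0, 0].item())  # ~1.0: a proper distribution over vocab
print(implicit.exp().sum(dim=-1)[0, 0].item())  # generally far from 1.0
print(torch.allclose(explicit, implicit))       # False
```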