diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 7c823c7e2..f99ed5e28 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -158,6 +158,7 @@ class SimpleConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
+        save_dir="./model",
     ):
         super().__init__(
             num_producers,
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 116f724e0..34ac2eec8 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
         generate_config=None,
         grpo_config={},
         project_name=None,
+        save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -63,6 +64,7 @@ class GRPOConsumer(BaseConsumer):
             model_config,
             plugin_config,
             microbatch_size,
+            save_dir=save_dir,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -173,7 +175,7 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
-        forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
+        train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

         reward_group = self.reward_model(
             data["input_ids"],
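Reviewer note: the rename above is cosmetic but clarifies intent: the value is the training micro-batch size, i.e. the stride with which the rollout batch is walked during the policy update (see the loop hunk below). A minimal sketch of that slicing pattern, with illustrative names rather than the repo's actual code:

```python
import torch

def iter_microbatches(batch: dict, train_microbatch_size: int):
    """Yield aligned micro-batch views of every tensor in `batch`."""
    total = batch["input_ids"].size(0)
    for start in range(0, total, train_microbatch_size):
        # Apply the same [start : start + size] window to every field so
        # input_ids, masks, and advantages stay aligned per sample.
        yield {k: v[start : start + train_microbatch_size] for k, v in batch.items()}

batch = {
    "input_ids": torch.randint(0, 100, (8, 16)),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
}
sizes = [mb["input_ids"].size(0) for mb in iter_microbatches(batch, 3)]
assert sizes == [3, 3, 2]  # the last micro-batch may be smaller
```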
@@ -222,11 +224,11 @@ class GRPOConsumer(BaseConsumer):
         # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
         # balance between efficiency and accuracy
-        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.95
+        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
         pbar.set_postfix(
             {
                 "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.95}",
+                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )
@@ -237,23 +239,23 @@ class GRPOConsumer(BaseConsumer):
             else self.booster.no_sync(self.policy_model, self.optimizer)
         )
         with ctx:
-            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+            for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
                 action_mask_forward_micro_batch = action_mask[
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
                 loss_mask_forward_micro_batch = (
-                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
                     if loss_mask is not None
                     else None
                 )
                 advantages_forward_micro_batch = advantages[
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]

                 if self.plugin.pp_size > 1:
@@ -442,7 +444,7 @@ class GRPOConsumer(BaseConsumer):
                 [
                     f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
                     f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
-                    f"ormat Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
+                    f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index 17c71c8a8..7d32cd52a 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -184,6 +184,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
 class VLLMInferenceBackend(BaseInferenceBackend):
     DEFAULT_MODEL_CONFIG = dict(
         trust_remote_code=True,
+        enable_sleep_mode=False,
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
@@ -205,6 +206,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
+        self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
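Reviewer note: `enable_sleep_mode` is off by default, and the new `self.model_config = model_config` line keeps the resolved config around so the producer (below) can check the flag at runtime. For readers unfamiliar with the feature, a minimal standalone sketch of vLLM's sleep-mode API, assuming a vLLM version that supports it; the model path is a placeholder:

```python
from vllm import LLM, SamplingParams

# Placeholder model; sleep mode needs CUDA and a recent vLLM release.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)

out = llm.generate(["1 + 1 ="], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)

llm.sleep()    # offload weights and drop the KV cache to free GPU memory
llm.wake_up()  # restore the engine before the next generation round
```

The producer below wraps exactly this pair around its weight-sync step, so the KV cache is not held while the updated state dict is loaded.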
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 8936752d2..30bd9cb16 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -107,6 +107,7 @@ def launch_distributed(
         grpo_config=grpo_config,
         num_generations=num_generations,
         project_name=project_name,
+        save_dir=grpo_config.get("save_dir", f"./model/{project_name}"),
     )
     procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 780d21d5b..f10a97794 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -113,6 +113,10 @@ class BaseProducer:
                 if (i + 1) % self.num_microbatches == 0 and (
                     episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                 ):
+                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+                        "enable_sleep_mode", False
+                    ):
+                        self.model.llm.sleep()  # evict KV cache to avoid OOM
                     # don't sync model for last iteration
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
@@ -125,6 +129,10 @@ class BaseProducer:
                     self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
+                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+                        "enable_sleep_mode", False
+                    ):
+                        self.model.llm.wake_up()
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index d79dbf169..c7e7474d0 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -62,6 +62,7 @@ if __name__ == "__main__":
         args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 and less than train mini batch size * num generations"
+    assert args.train_minibatch_size < args.train_batch_size, "Train mini batch size must be less than train batch size"

     if args.master_address is None:
         # Default settings: Using single machine
@@ -71,7 +72,7 @@ if __name__ == "__main__":
         ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)

     inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
+    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
     generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)

     if args.backend == "transformers":
@@ -96,13 +97,13 @@ if __name__ == "__main__":
                 gpu_memory_utilization=0.7,
                 enforce_eager=True,
                 enable_chunked_prefill=True,
-                max_model_len=1024 * 10 + 510,
+                max_model_len=1024 * 4 + 510,
                 tensor_parallel_size=1,
             )
         )
         generate_config.update(
             dict(
-                max_tokens=1024 * 10,
+                max_tokens=1024 * 4,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=[""],
@@ -139,7 +140,7 @@ if __name__ == "__main__":
             "beta": 0.0,  # no KL penalty
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
-            "max_length": 1024 * 10,
+            "max_length": 1024 * 4,
             "cache_length": 512,
             "filter_truncated_response": True,
         }
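Reviewer note: the new assertion completes the ordering between the batch-size knobs, and the 1024 * 4 edits keep the three length settings consistent: `max_model_len` (the generation budget plus what appears to be a ~510-token prompt allowance), vLLM's `max_tokens`, and the GRPO `max_length`. An illustrative sanity check, with example values rather than recommended settings:

```python
# Example values only; they satisfy the constraints rl_example.py asserts.
train_batch_size = 32
train_minibatch_size = 16
train_microbatch_size = 8
num_generations = 8

assert train_microbatch_size > 0
assert train_minibatch_size * num_generations >= train_microbatch_size
assert train_minibatch_size < train_batch_size  # the newly added check

max_tokens = 1024 * 4             # per-sample generation budget
max_model_len = max_tokens + 510  # plus the assumed prompt allowance
assert max_model_len == 1024 * 4 + 510
```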