From 4c5879d92b6de8d084d39cc65742e668efbb92ac Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 7 May 2025 10:51:08 +0800
Subject: [PATCH] Revert "fix bug"

This reverts commit 01640ebd650baa173929743b1a609693b0380065.
---
 .../ColossalChat/coati/distributed/launch.py  |  2 +-
 .../coati/distributed/producer.py             | 19 +++++++++++--------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 400a62928..d1b6158e5 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -87,7 +87,7 @@ def launch_distributed(
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
             eval_dataset_config=eval_dataset_config,
-            eval_interval=eval_interval * num_recv_per_update,
+            eval_interval=eval_interval,
             evaluation_function_type=grpo_config["reward_fn_type"],
             eval_save_dir=eval_save_dir,
         )
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index de4d60b89..b7e1d8f2a 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -129,7 +129,7 @@ class BaseProducer:
         else:
             raise ValueError(f"Unexpected backend {backend}")
 
-        self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
+        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size
 
     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
@@ -250,11 +250,14 @@ class BaseProducer:
                     # linear annealing for 1 episode, temperature from initial to 0.9
                     if episode <= 0:
                         ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
-                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
-                        if hasattr(self.model, "sample_params"):
-                            self.model.sample_params.temperature = self.model.generate_config["temperature"]
+                        if isinstance(self.model.generate_config.temperature, dict):
+                            self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                                "temperature"
+                            ] + ratio * 0.9
+                        else:
+                            self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                                "temperature"
+                            ] + ratio * 0.9
 
 
 @ray.remote
@@ -307,8 +310,8 @@ class SimpleProducer(BaseProducer):
 
     @torch.no_grad()
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
-        if self.producer_idx == 1:
-            print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
+        # if self.producer_idx == 1:
+        #     print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
 
         return rollouts
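
Note: the producer.py hunk above restores a linear anneal of the sampling temperature over the first episode, from the configured initial value down to 0.9. A minimal standalone sketch of that schedule follows; the function name and arguments are illustrative only and not part of the ColossalChat API:

    def annealed_temperature(step: int, steps_per_episode: int, initial_temperature: float) -> float:
        # Hypothetical helper mirroring the formula in the diff, not repo code.
        # ratio goes from 0 at the start of the episode to 1 at its end
        ratio = 1 - (steps_per_episode - step) / steps_per_episode
        # linear interpolation between the initial temperature and 0.9
        return (1 - ratio) * initial_temperature + ratio * 0.9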