This commit is contained in:
YeAnbang 2025-04-30 22:53:12 +08:00
parent de0c267f5a
commit 1be993de3e
2 changed files with 9 additions and 12 deletions

View File

@@ -87,7 +87,7 @@ def launch_distributed(
num_generations=num_generations, num_generations=num_generations,
consumer_plugin_config=plugin_config, consumer_plugin_config=plugin_config,
eval_dataset_config=eval_dataset_config, eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval, eval_interval=eval_interval * num_recv_per_update,
evaluation_function_type=grpo_config["reward_fn_type"], evaluation_function_type=grpo_config["reward_fn_type"],
eval_save_dir=eval_save_dir, eval_save_dir=eval_save_dir,
) )

View File

@@ -129,7 +129,7 @@ class BaseProducer:
else: else:
raise ValueError(f"Unexpected backend {backend}") raise ValueError(f"Unexpected backend {backend}")
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
def setup(self) -> None: def setup(self) -> None:
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}") cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
@@ -250,14 +250,11 @@ class BaseProducer:
# linear annealing for 1 episode, temperature from initial to 0.9 # linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0: if episode <= 0:
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader) ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
if isinstance(self.model.generate_config.temperature, dict):
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[ self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
"temperature" "temperature"
] + ratio * 0.9 ] + ratio * 0.9
else: if hasattr(self.model, "sample_params"):
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ self.model.sample_params.temperature = self.model.generate_config["temperature"]
"temperature"
] + ratio * 0.9
@ray.remote @ray.remote
@@ -310,8 +307,8 @@ class SimpleProducer(BaseProducer):
@torch.no_grad() @torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs): def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs) rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
# if self.producer_idx == 1: if self.producer_idx == 1:
# print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
return rollouts return rollouts