mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-24 02:30:56 +00:00
fix bug
This commit is contained in:
parent
de0c267f5a
commit
1be993de3e
@ -87,7 +87,7 @@ def launch_distributed(
|
|||||||
num_generations=num_generations,
|
num_generations=num_generations,
|
||||||
consumer_plugin_config=plugin_config,
|
consumer_plugin_config=plugin_config,
|
||||||
eval_dataset_config=eval_dataset_config,
|
eval_dataset_config=eval_dataset_config,
|
||||||
eval_interval=eval_interval,
|
eval_interval=eval_interval * num_recv_per_update,
|
||||||
evaluation_function_type=grpo_config["reward_fn_type"],
|
evaluation_function_type=grpo_config["reward_fn_type"],
|
||||||
eval_save_dir=eval_save_dir,
|
eval_save_dir=eval_save_dir,
|
||||||
)
|
)
|
||||||
|
@ -129,7 +129,7 @@ class BaseProducer:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected backend {backend}")
|
raise ValueError(f"Unexpected backend {backend}")
|
||||||
|
|
||||||
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size
|
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
|
||||||
|
|
||||||
def setup(self) -> None:
|
def setup(self) -> None:
|
||||||
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
|
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
|
||||||
@ -250,14 +250,11 @@ class BaseProducer:
|
|||||||
# linear annealing for 1 episode, temperature from initial to 0.9
|
# linear annealing for 1 episode, temperature from initial to 0.9
|
||||||
if episode <= 0:
|
if episode <= 0:
|
||||||
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
||||||
if isinstance(self.model.generate_config.temperature, dict):
|
|
||||||
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||||
"temperature"
|
"temperature"
|
||||||
] + ratio * 0.9
|
] + ratio * 0.9
|
||||||
else:
|
if hasattr(self.model, "sample_params"):
|
||||||
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
|
self.model.sample_params.temperature = self.model.generate_config["temperature"]
|
||||||
"temperature"
|
|
||||||
] + ratio * 0.9
|
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
@ -310,8 +307,8 @@ class SimpleProducer(BaseProducer):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||||
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
|
||||||
# if self.producer_idx == 1:
|
if self.producer_idx == 1:
|
||||||
# print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
|
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
|
||||||
|
|
||||||
return rollouts
|
return rollouts
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user