mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 12:43:55 +00:00
move logging to producer
This commit is contained in:
parent
47a7dc7142
commit
50070c1e84
@ -36,7 +36,6 @@ class BaseConsumer:
|
|||||||
minibatch_size: int = 1,
|
minibatch_size: int = 1,
|
||||||
save_interval: int = 100,
|
save_interval: int = 100,
|
||||||
save_dir: str = "./model",
|
save_dir: str = "./model",
|
||||||
eval_interval: int = -1,
|
|
||||||
):
|
):
|
||||||
self.num_producers = num_producers
|
self.num_producers = num_producers
|
||||||
self.num_episodes = num_episodes
|
self.num_episodes = num_episodes
|
||||||
@ -52,7 +51,6 @@ class BaseConsumer:
|
|||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||||
self.num_microbatches = batch_size // minibatch_size
|
self.num_microbatches = batch_size // minibatch_size
|
||||||
self.eval_interval = eval_interval
|
|
||||||
|
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.plugin_config = plugin_config
|
self.plugin_config = plugin_config
|
||||||
@ -94,9 +92,6 @@ class BaseConsumer:
|
|||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
|
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
|
||||||
|
|
||||||
for i in range(self.num_producers):
|
|
||||||
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
|
|
||||||
|
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
self.recv_cnt = 0
|
self.recv_cnt = 0
|
||||||
|
|
||||||
@ -114,24 +109,6 @@ class BaseConsumer:
|
|||||||
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
|
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
|
||||||
for step in pbar:
|
for step in pbar:
|
||||||
i = 0
|
i = 0
|
||||||
if self.eval_interval > 0 and step % self.eval_interval == 0:
|
|
||||||
eval_statistics = None
|
|
||||||
for r in range(self.num_producers):
|
|
||||||
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
|
|
||||||
local_eval_result = ray_broadcast_tensor_dict(
|
|
||||||
None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
|
|
||||||
)
|
|
||||||
if eval_statistics is None:
|
|
||||||
eval_statistics = local_eval_result
|
|
||||||
else:
|
|
||||||
eval_statistics = {
|
|
||||||
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
|
|
||||||
}
|
|
||||||
eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
|
|
||||||
self.wandb_run.log(eval_statistics, step=self.global_step)
|
|
||||||
print(f"Eval statistics: {eval_statistics}")
|
|
||||||
for _ in range(self.num_recv_per_update):
|
for _ in range(self.num_recv_per_update):
|
||||||
# receive data from producers
|
# receive data from producers
|
||||||
for r in range(self.num_producers):
|
for r in range(self.num_producers):
|
||||||
@ -214,7 +191,6 @@ class SimpleConsumer(BaseConsumer):
|
|||||||
minibatch_size=1,
|
minibatch_size=1,
|
||||||
save_interval: int = 100,
|
save_interval: int = 100,
|
||||||
save_dir="./model",
|
save_dir="./model",
|
||||||
eval_interval: int = -1,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_producers,
|
num_producers,
|
||||||
@ -231,7 +207,6 @@ class SimpleConsumer(BaseConsumer):
|
|||||||
minibatch_size,
|
minibatch_size,
|
||||||
save_interval,
|
save_interval,
|
||||||
save_dir,
|
save_dir,
|
||||||
eval_interval,
|
|
||||||
)
|
)
|
||||||
path = model_config.pop("path")
|
path = model_config.pop("path")
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||||
|
@ -34,13 +34,13 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
plugin_config,
|
plugin_config,
|
||||||
minibatch_size=1,
|
minibatch_size=1,
|
||||||
num_generations=8,
|
num_generations=8,
|
||||||
use_wandb=True,
|
|
||||||
generate_config=None,
|
generate_config=None,
|
||||||
grpo_config={},
|
grpo_config={},
|
||||||
project_name=None,
|
|
||||||
save_interval: int = 100,
|
save_interval: int = 100,
|
||||||
save_dir="./model",
|
save_dir="./model",
|
||||||
eval_interval: int = -1,
|
project_name: str = None,
|
||||||
|
run_name: str = None,
|
||||||
|
wandb_group_name: str = None,
|
||||||
):
|
):
|
||||||
print(f"Using GRPO config: {grpo_config}")
|
print(f"Using GRPO config: {grpo_config}")
|
||||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||||
@ -73,7 +73,6 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
minibatch_size,
|
minibatch_size,
|
||||||
save_interval=save_interval,
|
save_interval=save_interval,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
eval_interval=eval_interval,
|
|
||||||
)
|
)
|
||||||
path = model_config.pop("path")
|
path = model_config.pop("path")
|
||||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||||
@ -93,6 +92,9 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.project_name = project_name
|
self.project_name = project_name
|
||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
self.total_sample_count = 0
|
self.total_sample_count = 0
|
||||||
|
self.project_name = project_name
|
||||||
|
self.run_name = run_name
|
||||||
|
self.wandb_group_name = wandb_group_name
|
||||||
|
|
||||||
self.policy_loss_fn = PolicyLoss(
|
self.policy_loss_fn = PolicyLoss(
|
||||||
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
|
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
|
||||||
@ -143,7 +145,6 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
**reward_model_kwargs,
|
**reward_model_kwargs,
|
||||||
)
|
)
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
self.use_wandb = use_wandb
|
|
||||||
|
|
||||||
self.lr_scheduler = CosineAnnealingWarmupLR(
|
self.lr_scheduler = CosineAnnealingWarmupLR(
|
||||||
optimizer=self.optimizer,
|
optimizer=self.optimizer,
|
||||||
@ -154,13 +155,16 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
super().setup()
|
super().setup()
|
||||||
if self.use_wandb and (
|
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||||
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
|
|
||||||
):
|
):
|
||||||
# Initialize wandb.
|
self.wandb_run = wandb.init(
|
||||||
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
|
project=self.project_name,
|
||||||
self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
|
sync_tensorboard=True,
|
||||||
|
dir="./wandb",
|
||||||
|
name=self.run_name,
|
||||||
|
group=self.wandb_group_name,
|
||||||
|
)
|
||||||
|
|
||||||
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
||||||
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
||||||
@ -512,8 +516,8 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
}
|
}
|
||||||
if self.policy_loss_fn.beta > 0:
|
if self.policy_loss_fn.beta > 0:
|
||||||
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
|
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
|
||||||
|
if self.wandb_run is not None:
|
||||||
self.wandb_run.log(metrics)
|
self.wandb_run.log(metrics)
|
||||||
self.accum_loss.zero_()
|
self.accum_loss.zero_()
|
||||||
self.accum_reward.zero_()
|
self.accum_reward.zero_()
|
||||||
self.accum_ans_acc.zero_()
|
self.accum_ans_acc.zero_()
|
||||||
|
@ -238,7 +238,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
|||||||
log_probs.append(p)
|
log_probs.append(p)
|
||||||
|
|
||||||
# pad them
|
# pad them
|
||||||
max_len = self.generate_config.max_tokens
|
max_len = self.sample_params.max_tokens
|
||||||
action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
|
action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
|
||||||
|
|
||||||
for i, new_token_ids in enumerate(out_tokens):
|
for i, new_token_ids in enumerate(out_tokens):
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
|
import uuid
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
@ -53,6 +54,7 @@ def launch_distributed(
|
|||||||
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
||||||
eval_interval: int = 100,
|
eval_interval: int = 100,
|
||||||
eval_save_dir: Optional[str] = None,
|
eval_save_dir: Optional[str] = None,
|
||||||
|
eval_generation_config: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if core_algo not in ALGO_MAP:
|
if core_algo not in ALGO_MAP:
|
||||||
@ -69,6 +71,9 @@ def launch_distributed(
|
|||||||
num_update_per_episode = num_samples // global_inference_batch_size
|
num_update_per_episode = num_samples // global_inference_batch_size
|
||||||
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
||||||
|
|
||||||
|
run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
|
||||||
|
wandb_group_name = str(uuid.uuid4())
|
||||||
|
|
||||||
procs = []
|
procs = []
|
||||||
for i in range(num_producers):
|
for i in range(num_producers):
|
||||||
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
|
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
|
||||||
@ -90,6 +95,10 @@ def launch_distributed(
|
|||||||
eval_interval=eval_interval,
|
eval_interval=eval_interval,
|
||||||
evaluation_function_type=grpo_config["reward_fn_type"],
|
evaluation_function_type=grpo_config["reward_fn_type"],
|
||||||
eval_save_dir=eval_save_dir,
|
eval_save_dir=eval_save_dir,
|
||||||
|
eval_generation_config=eval_generation_config,
|
||||||
|
project_name=project_name,
|
||||||
|
run_name=run_name,
|
||||||
|
wandb_group_name=wandb_group_name,
|
||||||
)
|
)
|
||||||
procs.append(producer)
|
procs.append(producer)
|
||||||
generate_config_consumer = copy.deepcopy(generate_config)
|
generate_config_consumer = copy.deepcopy(generate_config)
|
||||||
@ -115,10 +124,11 @@ def launch_distributed(
|
|||||||
generate_config=generate_config_consumer,
|
generate_config=generate_config_consumer,
|
||||||
grpo_config=grpo_config,
|
grpo_config=grpo_config,
|
||||||
num_generations=num_generations,
|
num_generations=num_generations,
|
||||||
project_name=project_name,
|
|
||||||
save_interval=save_interval,
|
save_interval=save_interval,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
eval_interval=eval_interval,
|
project_name=project_name,
|
||||||
|
run_name=run_name,
|
||||||
|
wandb_group_name=wandb_group_name,
|
||||||
)
|
)
|
||||||
procs.append(consumer)
|
procs.append(consumer)
|
||||||
ray.get([p.setup.remote() for p in procs])
|
ray.get([p.setup.remote() for p in procs])
|
||||||
|
@ -6,8 +6,11 @@ import ray
|
|||||||
import ray.util.collective as cc
|
import ray.util.collective as cc
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
import wandb
|
||||||
from coati.dataset.loader import RawConversationDataset
|
from coati.dataset.loader import RawConversationDataset
|
||||||
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
||||||
|
from ray.util.collective import allreduce
|
||||||
|
from ray.util.collective.types import Backend, ReduceOp
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
@ -15,7 +18,7 @@ from colossalai.utils import get_current_device
|
|||||||
|
|
||||||
from .comm import ray_broadcast_tensor_dict
|
from .comm import ray_broadcast_tensor_dict
|
||||||
from .inference_backend import BACKEND_MAP
|
from .inference_backend import BACKEND_MAP
|
||||||
from .utils import pre_send, safe_write_jsonl
|
from .utils import pre_send, safe_append_to_jsonl_file
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
@ -43,6 +46,9 @@ class BaseProducer:
|
|||||||
eval_interval=-1, # disable evaluation
|
eval_interval=-1, # disable evaluation
|
||||||
evaluation_function_type="think_answer_tags",
|
evaluation_function_type="think_answer_tags",
|
||||||
eval_save_dir: str = "./eval",
|
eval_save_dir: str = "./eval",
|
||||||
|
project_name: str = None,
|
||||||
|
run_name: str = None,
|
||||||
|
wandb_group_name: str = None,
|
||||||
):
|
):
|
||||||
self.producer_idx = producer_idx
|
self.producer_idx = producer_idx
|
||||||
self.num_producers = num_producers
|
self.num_producers = num_producers
|
||||||
@ -61,6 +67,14 @@ class BaseProducer:
|
|||||||
self.eval_interval = eval_interval
|
self.eval_interval = eval_interval
|
||||||
self.eval_save_dir = eval_save_dir
|
self.eval_save_dir = eval_save_dir
|
||||||
self.consumer_global_step = 0
|
self.consumer_global_step = 0
|
||||||
|
if self.producer_idx == 0:
|
||||||
|
self.wandb_run = wandb.init(
|
||||||
|
project=project_name,
|
||||||
|
sync_tensorboard=True,
|
||||||
|
dir="./wandb",
|
||||||
|
name=run_name + "_eval",
|
||||||
|
group=wandb_group_name,
|
||||||
|
)
|
||||||
|
|
||||||
if os.path.exists(self.eval_save_dir):
|
if os.path.exists(self.eval_save_dir):
|
||||||
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
|
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
|
||||||
@ -132,13 +146,18 @@ class BaseProducer:
|
|||||||
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
|
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
|
||||||
|
|
||||||
def setup(self) -> None:
|
def setup(self) -> None:
|
||||||
|
cc.init_collective_group(
|
||||||
|
world_size=self.num_producers,
|
||||||
|
rank=self.producer_idx,
|
||||||
|
backend=Backend.NCCL,
|
||||||
|
group_name="producer_group",
|
||||||
|
)
|
||||||
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
|
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
|
||||||
if self.consumer_pp_size > 1:
|
if self.consumer_pp_size > 1:
|
||||||
for i in range(self.consumer_pp_size):
|
for i in range(self.consumer_pp_size):
|
||||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
|
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
|
||||||
else:
|
else:
|
||||||
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
|
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
|
||||||
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
|
|
||||||
|
|
||||||
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -160,13 +179,14 @@ class BaseProducer:
|
|||||||
break
|
break
|
||||||
if self.eval_interval > 0 and self.eval_dataset_config is not None:
|
if self.eval_interval > 0 and self.eval_dataset_config is not None:
|
||||||
if i % self.eval_interval == 0:
|
if i % self.eval_interval == 0:
|
||||||
eval_statistics = {}
|
to_log_msg = {}
|
||||||
for eval_task_name in self.eval_dataloaders:
|
for eval_task_name in self.eval_dataloaders:
|
||||||
print(
|
if self.producer_idx == 0:
|
||||||
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
|
print(
|
||||||
)
|
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
|
||||||
|
)
|
||||||
eval_results = []
|
eval_results = []
|
||||||
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
|
eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
|
||||||
for eval_batch in tqdm.tqdm(
|
for eval_batch in tqdm.tqdm(
|
||||||
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
|
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
|
||||||
):
|
):
|
||||||
@ -182,24 +202,27 @@ class BaseProducer:
|
|||||||
for m in range(eval_outputs["input_ids"].size(0))
|
for m in range(eval_outputs["input_ids"].size(0))
|
||||||
for n in range(eval_outputs["input_ids"].size(1))
|
for n in range(eval_outputs["input_ids"].size(1))
|
||||||
]
|
]
|
||||||
eval_statistics[eval_task_name][0] += len(
|
eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
|
||||||
[res for res in eval_results if res["ans_valid"] == 1]
|
eval_statistics_tensor[1] += len(eval_results)
|
||||||
|
allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
|
||||||
|
to_log_msg[f"eval/{eval_task_name}"] = (
|
||||||
|
eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
|
||||||
)
|
)
|
||||||
eval_statistics[eval_task_name][1] += len(eval_results)
|
if self.producer_idx == 0:
|
||||||
|
print(
|
||||||
|
f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
|
||||||
|
)
|
||||||
# save eval results
|
# save eval results
|
||||||
result_file_name = os.path.join(
|
safe_append_to_jsonl_file(
|
||||||
self.eval_save_dir,
|
os.path.join(
|
||||||
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
|
self.eval_save_dir,
|
||||||
|
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
|
||||||
|
),
|
||||||
|
eval_results,
|
||||||
)
|
)
|
||||||
# delete the file if it exists
|
|
||||||
safe_write_jsonl(result_file_name, eval_results)
|
if self.producer_idx == 0:
|
||||||
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
|
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
|
||||||
ray_broadcast_tensor_dict(
|
|
||||||
eval_statistics,
|
|
||||||
src=0,
|
|
||||||
device=self.device,
|
|
||||||
group_name=f"sync_eval_statistics_{self.producer_idx}",
|
|
||||||
)
|
|
||||||
outputs = self.rollout(**batch)
|
outputs = self.rollout(**batch)
|
||||||
|
|
||||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||||
@ -248,12 +271,11 @@ class BaseProducer:
|
|||||||
# linear annealing for 1 episode, temperature from initial to 0.9
|
# linear annealing for 1 episode, temperature from initial to 0.9
|
||||||
if episode <= 0:
|
if episode <= 0:
|
||||||
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
||||||
if isinstance(self.model.generate_config.temperature, dict):
|
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||||
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
"temperature"
|
||||||
"temperature"
|
] + ratio * 0.9
|
||||||
] + ratio * 0.9
|
if isinstance(self.model, BACKEND_MAP["vllm"]):
|
||||||
else:
|
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
|
||||||
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
|
|
||||||
"temperature"
|
"temperature"
|
||||||
] + ratio * 0.9
|
] + ratio * 0.9
|
||||||
|
|
||||||
@ -280,6 +302,10 @@ class SimpleProducer(BaseProducer):
|
|||||||
eval_interval=-1, # disable evaluation
|
eval_interval=-1, # disable evaluation
|
||||||
evaluation_function_type="think_answer_tags",
|
evaluation_function_type="think_answer_tags",
|
||||||
eval_save_dir: str = "./eval",
|
eval_save_dir: str = "./eval",
|
||||||
|
eval_generation_config={},
|
||||||
|
project_name: str = None,
|
||||||
|
run_name: str = None,
|
||||||
|
wandb_group_name: str = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
producer_idx,
|
producer_idx,
|
||||||
@ -299,10 +325,14 @@ class SimpleProducer(BaseProducer):
|
|||||||
eval_interval=eval_interval,
|
eval_interval=eval_interval,
|
||||||
evaluation_function_type=evaluation_function_type,
|
evaluation_function_type=evaluation_function_type,
|
||||||
eval_save_dir=eval_save_dir,
|
eval_save_dir=eval_save_dir,
|
||||||
|
project_name=project_name,
|
||||||
|
run_name=run_name,
|
||||||
|
wandb_group_name=wandb_group_name,
|
||||||
)
|
)
|
||||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||||
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||||
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||||
|
self.eval_generation_config.update(eval_generation_config)
|
||||||
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
|
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -135,7 +135,7 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.
|
|||||||
return tensor.sum(dim=dim)
|
return tensor.sum(dim=dim)
|
||||||
|
|
||||||
|
|
||||||
def safe_write_jsonl(file_path, data):
|
def safe_append_to_jsonl_file(file_path, data):
|
||||||
with FileLock(file_path + ".lock"):
|
with FileLock(file_path + ".lock"):
|
||||||
# Ensure file exists
|
# Ensure file exists
|
||||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
@ -161,6 +161,7 @@ if __name__ == "__main__":
|
|||||||
stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
|
||||||
elif args.backend == "vllm":
|
elif args.backend == "vllm":
|
||||||
inference_model_config.update(
|
inference_model_config.update(
|
||||||
dict(
|
dict(
|
||||||
@ -179,6 +180,7 @@ if __name__ == "__main__":
|
|||||||
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||||
|
|
||||||
@ -257,4 +259,5 @@ if __name__ == "__main__":
|
|||||||
},
|
},
|
||||||
eval_interval=args.eval_interval,
|
eval_interval=args.eval_interval,
|
||||||
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
||||||
|
eval_generation_config=eval_generation_config,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user