Support evaluation during training

YeAnbang 2025-04-30 18:13:40 +08:00
parent b920af427b
commit 47a7dc7142
9 changed files with 234 additions and 65 deletions

.gitignore vendored
View File

@@ -165,3 +165,4 @@ applications/ColossalChat/logs
applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval

applications/ColossalChat/coati/distributed/consumer.py
View File

@@ -36,6 +36,7 @@ class BaseConsumer:
minibatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
eval_interval: int = -1,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@@ -51,6 +52,7 @@ class BaseConsumer:
self.save_dir = save_dir
assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
self.num_microbatches = batch_size // minibatch_size
self.eval_interval = eval_interval
self.model_config = model_config
self.plugin_config = plugin_config
@@ -92,8 +94,10 @@ class BaseConsumer:
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
self.buffer = []
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_eval_statistics_{i}")
self.buffer = []
self.recv_cnt = 0
def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -110,6 +114,24 @@ class BaseConsumer:
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
for step in pbar:
i = 0
if self.eval_interval > 0 and step % self.eval_interval == 0:
eval_statistics = None
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
local_eval_result = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_eval_statistics_{r}"
)
if eval_statistics is None:
eval_statistics = local_eval_result
else:
eval_statistics = {
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
}
eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
if dist.get_rank() == 0:
if hasattr(self, "wandb_run") and hasattr(self, "global_step"):
self.wandb_run.log(eval_statistics, step=self.global_step)
print(f"Eval statistics: {eval_statistics}")
for _ in range(self.num_recv_per_update):
# receive data from producers
for r in range(self.num_producers):
@@ -192,6 +214,7 @@ class SimpleConsumer(BaseConsumer):
minibatch_size=1,
save_interval: int = 100,
save_dir="./model",
eval_interval: int = -1,
):
super().__init__(
num_producers,
@@ -206,6 +229,9 @@ class SimpleConsumer(BaseConsumer):
model_config,
plugin_config,
minibatch_size,
save_interval,
save_dir,
eval_interval,
)
path = model_config.pop("path")
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)

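Note: the eval hunk above averages evaluation results across producers. A minimal sketch (not the repo's code) of that reduction: each producer reports a per-task 2-element tensor [num_correct, num_total], and summing element-wise across producers before dividing weights the accuracy by sample count instead of averaging per-producer accuracies.

import torch

def aggregate_eval(per_producer_stats):
    # one dict per producer: task name -> tensor([num_correct, num_total])
    merged = {}
    for stats in per_producer_stats:
        for task, pair in stats.items():
            merged[task] = merged.get(task, torch.zeros(2)) + pair
    # divide once, after summing across producers
    return {task: (pair[0] / pair[1]).item() for task, pair in merged.items()}

# two producers, one task: (3 + 5) correct out of (8 + 8) samples -> 0.5
stats = aggregate_eval([{"math": torch.tensor([3.0, 8.0])},
                        {"math": torch.tensor([5.0, 8.0])}])
assert abs(stats["math"] - 0.5) < 1e-6
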
applications/ColossalChat/coati/distributed/grpo_consumer.py
View File

@@ -40,6 +40,7 @@ class GRPOConsumer(BaseConsumer):
project_name=None,
save_interval: int = 100,
save_dir="./model",
eval_interval: int = -1,
):
print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -72,6 +73,7 @@ class GRPOConsumer(BaseConsumer):
minibatch_size,
save_interval=save_interval,
save_dir=save_dir,
eval_interval=eval_interval,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -528,4 +530,5 @@ class GRPOConsumer(BaseConsumer):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
return state_dict

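Note: the state_dict hunk above piggybacks the trainer's global step on the weight sync, and the producer pops it back out before loading weights (see producer.py below). A hedged sketch of the pattern, with hypothetical pack/unpack helpers:

import torch

def pack(model_state, global_step):
    # ship the scalar as a tensor so it rides along with the weight broadcast
    state = dict(model_state)
    state["consumer_global_step"] = torch.tensor([global_step])
    return state

def unpack(state):
    # pop the metadata before the rest reaches load_state_dict()
    step = -1
    if "consumer_global_step" in state:
        step = int(state.pop("consumer_global_step").item())
    return state, step

weights, step = unpack(pack({"w": torch.zeros(2)}, 42))
assert step == 42 and "consumer_global_step" not in weights
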
applications/ColossalChat/coati/distributed/inference_backend.py
View File

@@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
self.generate_config = generate_config
self.sample_params = SamplingParams(**generate_config)
self.model_config = model_config
self.tokenizer = tokenizer
self.num_generations = num_generations
@@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend):
micro_batch_input_ids_no_padding = [
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
]
sample_params = kwargs.get("sample_params", self.sample_params)
outputs = self.llm.generate(
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
)
out_tokens = []
out_len = []
@@ -266,11 +268,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
"response_idx": response_idx,
}
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
if "gt_answer" in kwargs:
# repeat gt_answer for each prompt.
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data

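Note: the backend now keeps generate_config as a plain dict and accepts a per-call sample_params override, which the producer uses to evaluate with n=1 while training keeps n=num_generations. That is also why the reshape above switches from self.num_generations to -1. A small sketch with illustrative sizes:

import torch

micro_batch_size, seq_len = 4, 16
for n in (8, 1):  # training default vs. the n=1 evaluation override
    flat = torch.zeros(micro_batch_size * n, seq_len)  # vLLM returns batch * n sequences
    shaped = flat.view(micro_batch_size, -1, seq_len)  # -1 adapts to either n
    assert shaped.shape == (micro_batch_size, n, seq_len)
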
applications/ColossalChat/coati/distributed/launch.py
View File

@@ -34,7 +34,7 @@ def launch_distributed(
inference_microbatch_size: int,
train_batch_size: int,
train_minibatch_size: int,
dataset_config: Dict[str, Any],
train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
@@ -50,6 +50,9 @@ def launch_distributed(
project_name: Optional[str] = None,
save_interval: int = 100,
save_dir: str = "./model",
eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
):
if core_algo not in ALGO_MAP:
@@ -60,9 +63,9 @@
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
dataset_path = dataset_config["path"]
dataset_path = train_dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size
@@ -74,7 +77,7 @@ def launch_distributed(
num_consumer_procs=num_consumer_procs,
num_episodes=num_episodes,
batch_size=inference_batch_size,
dataset_config=dataset_config,
train_dataset_config=train_dataset_config,
dataloaders_config=dataloaders_config,
model_config=inference_model_config,
generate_config=generate_config,
@@ -83,6 +86,10 @@ def launch_distributed(
backend=inference_backend,
num_generations=num_generations,
consumer_plugin_config=plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=grpo_config["reward_fn_type"],
eval_save_dir=eval_save_dir,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
@@ -111,6 +118,7 @@ def launch_distributed(
project_name=project_name,
save_interval=save_interval,
save_dir=save_dir,
eval_interval=eval_interval,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])

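Note: for reference, a sketch of the per-task eval_dataset_config that launch_distributed forwards to the producers, mirroring how rl_example.py (last file below) builds it; the task name and values here are assumptions, not repo defaults:

eval_dataset_config = {
    "my_eval_task": {                       # hypothetical task name
        "path": "data_eval.jsonl",          # same jsonl format as the training set
        "max_length": 512,                  # prompt truncation length
        "system_prompt": "Please reason step by step.",
    },
}
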
applications/ColossalChat/coati/distributed/producer.py
View File

@@ -1,9 +1,13 @@
import copy
import os
from typing import Any, Dict, Optional
import ray
import ray.util.collective as cc
import torch
import tqdm
from coati.dataset.loader import RawConversationDataset
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
@@ -11,7 +15,12 @@ from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send
from .utils import pre_send, safe_write_jsonl
try:
from vllm import SamplingParams
except ImportError:
SamplingParams = None
class BaseProducer:
@@ -22,7 +31,7 @@ class BaseProducer:
num_consumer_procs: int,
num_episodes: int,
batch_size: int,
dataset_config: Dict[str, Any],
train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
@@ -30,6 +39,10 @@ class BaseProducer:
microbatch_size: int = 1,
backend: str = "transformers",
consumer_plugin_config: Dict[str, Any] = None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
eval_save_dir: str = "./eval",
):
self.producer_idx = producer_idx
self.num_producers = num_producers
@@ -40,10 +53,17 @@ class BaseProducer:
assert batch_size % microbatch_size == 0
self.num_microbatches = batch_size // microbatch_size
self.dataset_config = dataset_config
self.train_dataset_config = train_dataset_config
self.model_config = model_config
self.generate_config = generate_config
self.tokenizer_config = tokenizer_config
self.consumer_plugin_config = consumer_plugin_config
self.eval_interval = eval_interval
self.eval_save_dir = eval_save_dir
self.consumer_global_step = 0
if os.path.exists(self.eval_save_dir):
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
# init tokenizer
if tokenizer_config is None:
@@ -55,13 +75,13 @@ class BaseProducer:
self.tokenizer.padding_side = "left"
# init dataloader
dataset_path = dataset_config.pop("path")
self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
self.dataloader = DataLoader(
self.dataset,
train_dataset_path = train_dataset_config.pop("path")
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
self.train_dataloader = DataLoader(
self.train_dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
self.dataset,
self.train_dataset,
num_replicas=num_producers,
rank=producer_idx,
shuffle=True,
@@ -71,6 +91,36 @@ class BaseProducer:
num_workers=4,
drop_last=True,
)
self.eval_dataset_config = eval_dataset_config
if self.eval_dataset_config is not None:
self.eval_dataloaders = {}
for eval_task_name in self.eval_dataset_config:
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
eval_dataset = RawConversationDataset(
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
)
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
self.eval_dataloaders[eval_task_name] = DataLoader(
eval_dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
eval_dataset,
num_replicas=num_producers,
rank=producer_idx,
shuffle=False,
drop_last=False,
seed=42,
),
)
if evaluation_function_type == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif evaluation_function_type == "boxed":
self.evaluation_function = boxed_math_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
else:
raise ValueError("eval_dataset_config is not defined")
self.device = get_current_device()
# init backend
@@ -88,6 +138,7 @@ class BaseProducer:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError
@@ -96,29 +147,64 @@ class BaseProducer:
raise NotImplementedError
def loop(self) -> None:
num_update_per_episode = len(self.dataloader) // self.num_microbatches
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches
print(
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
)
for episode in range(self.num_episodes):
self.dataloader.sampler.set_epoch(episode)
for i, batch in enumerate(self.dataloader):
self.train_dataloader.sampler.set_epoch(episode)
for i, batch in enumerate(self.train_dataloader):
if i >= num_valid_microbatches:
break
if self.eval_interval > 0 and self.eval_dataset_config is not None:
if i % self.eval_interval == 0:
eval_statistics = {}
for eval_task_name in self.eval_dataloaders:
print(
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
)
eval_results = []
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
for eval_batch in tqdm.tqdm(
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
):
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
eval_results = eval_results + [
self.evaluation_function(
eval_outputs["input_ids"][m][n],
eval_outputs["gt_answer"][m][n],
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
)
for m in range(eval_outputs["input_ids"].size(0))
for n in range(eval_outputs["input_ids"].size(1))
]
eval_statistics[eval_task_name][0] += len(
[res for res in eval_results if res["ans_valid"] == 1]
)
eval_statistics[eval_task_name][1] += len(eval_results)
# save eval results
result_file_name = os.path.join(
self.eval_save_dir,
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
)
# append eval results; safe_write_jsonl locks the file across producers
safe_write_jsonl(result_file_name, eval_results)
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
ray_broadcast_tensor_dict(
eval_statistics,
src=0,
device=self.device,
group_name=f"sync_eval_statistics_{self.producer_idx}",
)
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor(
[
(
self.model.generate_config["temperature"]
if isinstance(self.model.generate_config.temperature, dict)
else self.model.generate_config.temperature
)
]
* outputs["input_ids"].size(0)
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
@@ -150,6 +236,8 @@ class BaseProducer:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
del state_dict
torch.cuda.empty_cache()
@@ -159,7 +247,7 @@ class BaseProducer:
self.model.llm.wake_up()
# linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0:
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
if isinstance(self.model.generate_config.temperature, dict):
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
"temperature"
@@ -179,7 +267,7 @@ class SimpleProducer(BaseProducer):
num_consumer_procs,
num_episodes,
batch_size,
dataset_config,
train_dataset_config,
dataloaders_config,
model_config,
generate_config,
@@ -188,6 +276,10 @@ class SimpleProducer(BaseProducer):
backend="transformers",
num_generations: int = 8,
consumer_plugin_config=None,
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
eval_save_dir: str = "./eval",
):
super().__init__(
producer_idx,
@@ -195,7 +287,7 @@ class SimpleProducer(BaseProducer):
num_consumer_procs,
num_episodes,
batch_size,
dataset_config,
train_dataset_config,
dataloaders_config,
model_config,
generate_config,
@@ -203,14 +295,21 @@ class SimpleProducer(BaseProducer):
microbatch_size,
backend,
consumer_plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=evaluation_function_type,
eval_save_dir=eval_save_dir,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 1:
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
# if self.producer_idx == 1:
# print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
return rollouts

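Note: the evaluation loop above keeps raw counts per task (index 0: correct answers, index 1: samples evaluated) rather than a local accuracy, which is what lets the consumer sum tensors across producers and divide once. A minimal sketch of that counter:

import torch

def accumulate(results):
    stats = torch.zeros(2)
    stats[0] = sum(1 for r in results if r["ans_valid"] == 1)  # correct
    stats[1] = len(results)                                    # total evaluated
    return stats

stats = accumulate([{"ans_valid": 1}, {"ans_valid": 0}, {"ans_valid": 1}])
assert stats.tolist() == [2.0, 3.0]
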
applications/ColossalChat/coati/distributed/reward/reward_fn.py
View File

@@ -5,6 +5,7 @@ from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
eval_mode = kwargs.get("eval_mode", False)
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
acc_score = 10.0
reward = torch.tensor(0.0)
@@ -44,36 +45,23 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
reward = reward + length_reward
if not eval_mode:
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
def gsm8k_reward_fn(input_ids, **kwargs):
gt_answer = kwargs["gt_answer"]
tokenizer = kwargs["tokenizer"]
s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
is_valid = True
try:
int(final_answer.strip())
except Exception:
is_valid = False
format_valid = validate_response_structure(processed_str, kwargs["tags"])
if not is_valid or not format_valid:
return reward
else:
reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 9.0
return reward
prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
return {
"prompt": prompt,
"prediction": decoded_final_answer,
"gold": gt_answer,
"parsed": final_answer,
"format_valid": format_acc.item(),
"ans_valid": ans_acc.item(),
}
def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
eval_mode = kwargs.get("eval_mode", False)
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
format_score = 0.0
acc_score = 10.0
@@ -91,7 +79,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
if gt_answer is None:
return reward
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
@@ -108,5 +96,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
reward += acc_score
reward = reward + length_reward
if not eval_mode:
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
else:
prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
return {
"prompt": prompt,
"prediction": decoded_final_answer,
"gold": gt_answer,
"parsed": final_answer,
"format_valid": format_acc.item(),
"ans_valid": ans_acc.item(),
}

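Note: both reward functions now follow a dual-mode contract: training mode returns a [reward, format_acc, ans_acc] tensor for the GRPO update, while eval_mode=True returns a JSON-serializable record for the .jsonl dump. A toy sketch of that contract, with plain string matching standing in for the real answer extraction:

import torch

def reward_fn_stub(prediction, gold, eval_mode=False):
    # hypothetical stand-in for math_reward_fn / boxed_math_reward_fn
    ans_acc = 1.0 if prediction.strip() == gold.strip() else 0.0
    reward, format_acc = 10.0 * ans_acc, 1.0
    if not eval_mode:
        return torch.tensor([reward, format_acc, ans_acc])  # training path
    return {"prediction": prediction, "gold": gold,         # evaluation path
            "format_valid": format_acc, "ans_valid": ans_acc}

assert reward_fn_stub("42", "42")[2] == 1.0
assert reward_fn_stub("41", "42", eval_mode=True)["ans_valid"] == 0.0
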
applications/ColossalChat/coati/distributed/utils.py
View File

@@ -1,6 +1,9 @@
import json
import os
from typing import Any, Dict, List
import torch
from filelock import FileLock
from colossalai.shardformer.layer.loss import dist_log_prob
@@ -130,3 +133,13 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""
tensor = tensor * mask
return tensor.sum(dim=dim)
def safe_write_jsonl(file_path, data):
with FileLock(file_path + ".lock"):
# Ensure the parent directory exists
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "a", encoding="utf8") as f:
for entry in data:
json_line = json.dumps(entry, ensure_ascii=False)
f.write(json_line + "\n")

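Note: a usage sketch for safe_write_jsonl (the path and record are illustrative). Every producer appends to the same per-task results file; the sibling .lock file serializes the appends so lines from concurrent writers never interleave:

from coati.distributed.utils import safe_write_jsonl

records = [{"prompt": "1+1=", "prediction": "2", "gold": "2", "ans_valid": 1}]
safe_write_jsonl("./eval/GRPO/math_episode_0_step_0.jsonl", records)
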
applications/ColossalChat/rl_example.py
View File

@@ -1,4 +1,5 @@
import argparse
import json
import os
import ray
@@ -8,7 +9,16 @@ from coati.distributed.launch import launch_distributed
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-d", "--dataset", type=str, default="data_train.jsonl")
parser.add_argument(
"-ed",
"--eval-dataset",
type=str,
default='{"eval task name":"data_eval.jsonl"}',
help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
The key is the task name, and the value is the path to the jsonl file",
)
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
@@ -94,11 +104,14 @@
choices=["think_answer_tags", "boxed"],
help="Reward type for GRPO.",
)
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
parser.add_argument(
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
)
args = parser.parse_args()
if args.train_minibatch_size is None:
@@ -208,7 +221,7 @@
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_minibatch_size=args.train_minibatch_size,
dataset_config={
train_dataset_config={
"path": args.dataset,
"max_length": args.max_prompt_tokens,
"system_prompt": args.system_prompt,
@@ -238,4 +251,10 @@
project_name=args.project,
save_interval=args.save_interval,
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
eval_dataset_config={
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
for k, v in json.loads(args.eval_dataset).items()
},
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
)