mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-22 13:41:43 +00:00
Merge pull request #6292 from hpcaitech/grpo-latest-dev-reward-update
[feat] Update reward verification
This commit is contained in:
commit
17928ad84f
1
.gitignore
vendored
1
.gitignore
vendored
@ -165,3 +165,4 @@ applications/ColossalChat/logs
|
|||||||
applications/ColossalChat/tests/logs
|
applications/ColossalChat/tests/logs
|
||||||
applications/ColossalChat/wandb
|
applications/ColossalChat/wandb
|
||||||
applications/ColossalChat/model
|
applications/ColossalChat/model
|
||||||
|
applications/ColossalChat/eval
|
||||||
|
@ -36,6 +36,7 @@ class BaseConsumer:
|
|||||||
minibatch_size: int = 1,
|
minibatch_size: int = 1,
|
||||||
save_interval: int = 100,
|
save_interval: int = 100,
|
||||||
save_dir: str = "./model",
|
save_dir: str = "./model",
|
||||||
|
eval_interval: int = -1,
|
||||||
):
|
):
|
||||||
self.num_producers = num_producers
|
self.num_producers = num_producers
|
||||||
self.num_episodes = num_episodes
|
self.num_episodes = num_episodes
|
||||||
@ -51,6 +52,7 @@ class BaseConsumer:
|
|||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||||
self.num_microbatches = batch_size // minibatch_size
|
self.num_microbatches = batch_size // minibatch_size
|
||||||
|
self.eval_interval = eval_interval
|
||||||
|
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.plugin_config = plugin_config
|
self.plugin_config = plugin_config
|
||||||
@ -93,7 +95,6 @@ class BaseConsumer:
|
|||||||
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
|
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
|
||||||
|
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
|
|
||||||
self.recv_cnt = 0
|
self.recv_cnt = 0
|
||||||
|
|
||||||
def state_dict(self) -> Dict[str, torch.Tensor]:
|
def state_dict(self) -> Dict[str, torch.Tensor]:
|
||||||
@ -110,6 +111,27 @@ class BaseConsumer:
|
|||||||
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
|
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
|
||||||
for step in pbar:
|
for step in pbar:
|
||||||
i = 0
|
i = 0
|
||||||
|
if self.eval_interval > 0 and step % self.eval_interval == 0:
|
||||||
|
eval_statistics = None
|
||||||
|
eval_global_step = None
|
||||||
|
for r in range(self.num_producers):
|
||||||
|
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
|
||||||
|
local_eval_result = ray_broadcast_tensor_dict(
|
||||||
|
None, src=0, device=self.device, group_name=f"sync_data_{r}"
|
||||||
|
)
|
||||||
|
assert "consumer_global_step" in local_eval_result
|
||||||
|
eval_global_step = local_eval_result.pop("consumer_global_step").item()
|
||||||
|
if eval_statistics is None:
|
||||||
|
eval_statistics = local_eval_result
|
||||||
|
else:
|
||||||
|
eval_statistics = {
|
||||||
|
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
|
||||||
|
}
|
||||||
|
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
if hasattr(self, "wandb_run"):
|
||||||
|
self.wandb_run.log(eval_statistics, step=eval_global_step)
|
||||||
|
print(f"Eval statistics: {eval_statistics}")
|
||||||
for _ in range(self.num_recv_per_update):
|
for _ in range(self.num_recv_per_update):
|
||||||
# receive data from producers
|
# receive data from producers
|
||||||
for r in range(self.num_producers):
|
for r in range(self.num_producers):
|
||||||
@ -195,6 +217,7 @@ class SimpleConsumer(BaseConsumer):
|
|||||||
minibatch_size=1,
|
minibatch_size=1,
|
||||||
save_interval: int = 100,
|
save_interval: int = 100,
|
||||||
save_dir="./model",
|
save_dir="./model",
|
||||||
|
eval_interval: int = -1,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_producers,
|
num_producers,
|
||||||
@ -209,6 +232,9 @@ class SimpleConsumer(BaseConsumer):
|
|||||||
model_config,
|
model_config,
|
||||||
plugin_config,
|
plugin_config,
|
||||||
minibatch_size,
|
minibatch_size,
|
||||||
|
save_interval,
|
||||||
|
save_dir,
|
||||||
|
eval_interval,
|
||||||
)
|
)
|
||||||
path = model_config.pop("path")
|
path = model_config.pop("path")
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||||
|
@ -40,6 +40,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
project_name=None,
|
project_name=None,
|
||||||
save_interval: int = 100,
|
save_interval: int = 100,
|
||||||
save_dir="./model",
|
save_dir="./model",
|
||||||
|
eval_interval: int = -1,
|
||||||
):
|
):
|
||||||
print(f"Using GRPO config: {grpo_config}")
|
print(f"Using GRPO config: {grpo_config}")
|
||||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||||
@ -72,6 +73,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
minibatch_size,
|
minibatch_size,
|
||||||
save_interval=save_interval,
|
save_interval=save_interval,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
|
eval_interval=eval_interval,
|
||||||
)
|
)
|
||||||
path = model_config.pop("path")
|
path = model_config.pop("path")
|
||||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||||
@ -528,4 +530,5 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.policy_model._force_wait_all_gather()
|
self.policy_model._force_wait_all_gather()
|
||||||
model = self.policy_model.unwrap()
|
model = self.policy_model.unwrap()
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
|
state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
@ -205,7 +205,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
|||||||
generate_config = generate_config.copy()
|
generate_config = generate_config.copy()
|
||||||
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||||
generate_config.update({"n": num_generations})
|
generate_config.update({"n": num_generations})
|
||||||
self.generate_config = SamplingParams(**generate_config)
|
self.generate_config = generate_config
|
||||||
|
self.sample_params = SamplingParams(**generate_config)
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.num_generations = num_generations
|
self.num_generations = num_generations
|
||||||
@ -219,8 +220,9 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
|||||||
micro_batch_input_ids_no_padding = [
|
micro_batch_input_ids_no_padding = [
|
||||||
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
|
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
|
||||||
]
|
]
|
||||||
|
sample_params = kwargs.get("sample_params", self.sample_params)
|
||||||
outputs = self.llm.generate(
|
outputs = self.llm.generate(
|
||||||
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
|
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
|
||||||
)
|
)
|
||||||
out_tokens = []
|
out_tokens = []
|
||||||
out_len = []
|
out_len = []
|
||||||
@ -266,11 +268,11 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
|||||||
"response_idx": response_idx,
|
"response_idx": response_idx,
|
||||||
}
|
}
|
||||||
|
|
||||||
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
|
data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
|
||||||
|
|
||||||
if "gt_answer" in kwargs:
|
if "gt_answer" in kwargs:
|
||||||
# repeat gt_answer for each prompt.
|
# repeat gt_answer for each prompt.
|
||||||
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
|
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
|
||||||
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ def launch_distributed(
|
|||||||
inference_microbatch_size: int,
|
inference_microbatch_size: int,
|
||||||
train_batch_size: int,
|
train_batch_size: int,
|
||||||
train_minibatch_size: int,
|
train_minibatch_size: int,
|
||||||
dataset_config: Dict[str, Any],
|
train_dataset_config: Dict[str, Any],
|
||||||
dataloaders_config: Dict[str, Any],
|
dataloaders_config: Dict[str, Any],
|
||||||
inference_model_config: Dict[str, Any],
|
inference_model_config: Dict[str, Any],
|
||||||
generate_config: Dict[str, Any],
|
generate_config: Dict[str, Any],
|
||||||
@ -50,6 +50,9 @@ def launch_distributed(
|
|||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
save_interval: int = 100,
|
save_interval: int = 100,
|
||||||
save_dir: str = "./model",
|
save_dir: str = "./model",
|
||||||
|
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
||||||
|
eval_interval: int = 100,
|
||||||
|
eval_save_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if core_algo not in ALGO_MAP:
|
if core_algo not in ALGO_MAP:
|
||||||
@ -60,9 +63,9 @@ def launch_distributed(
|
|||||||
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
|
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
|
||||||
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
|
||||||
|
|
||||||
dataset_path = dataset_config["path"]
|
dataset_path = train_dataset_config["path"]
|
||||||
num_samples = get_jsonl_size_fast(dataset_path)
|
num_samples = get_jsonl_size_fast(dataset_path)
|
||||||
global_inference_batch_size = inference_batch_size * num_producers
|
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
|
||||||
num_update_per_episode = num_samples // global_inference_batch_size
|
num_update_per_episode = num_samples // global_inference_batch_size
|
||||||
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
num_recv_per_update = inference_batch_size // inference_microbatch_size
|
||||||
|
|
||||||
@ -74,7 +77,7 @@ def launch_distributed(
|
|||||||
num_consumer_procs=num_consumer_procs,
|
num_consumer_procs=num_consumer_procs,
|
||||||
num_episodes=num_episodes,
|
num_episodes=num_episodes,
|
||||||
batch_size=inference_batch_size,
|
batch_size=inference_batch_size,
|
||||||
dataset_config=dataset_config,
|
train_dataset_config=train_dataset_config,
|
||||||
dataloaders_config=dataloaders_config,
|
dataloaders_config=dataloaders_config,
|
||||||
model_config=inference_model_config,
|
model_config=inference_model_config,
|
||||||
generate_config=generate_config,
|
generate_config=generate_config,
|
||||||
@ -83,6 +86,10 @@ def launch_distributed(
|
|||||||
backend=inference_backend,
|
backend=inference_backend,
|
||||||
num_generations=num_generations,
|
num_generations=num_generations,
|
||||||
consumer_plugin_config=plugin_config,
|
consumer_plugin_config=plugin_config,
|
||||||
|
eval_dataset_config=eval_dataset_config,
|
||||||
|
eval_interval=eval_interval * num_recv_per_update,
|
||||||
|
evaluation_function_type=grpo_config["reward_fn_type"],
|
||||||
|
eval_save_dir=eval_save_dir,
|
||||||
)
|
)
|
||||||
procs.append(producer)
|
procs.append(producer)
|
||||||
generate_config_consumer = copy.deepcopy(generate_config)
|
generate_config_consumer = copy.deepcopy(generate_config)
|
||||||
@ -111,6 +118,7 @@ def launch_distributed(
|
|||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
save_interval=save_interval,
|
save_interval=save_interval,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
|
eval_interval=eval_interval,
|
||||||
)
|
)
|
||||||
procs.append(consumer)
|
procs.append(consumer)
|
||||||
ray.get([p.setup.remote() for p in procs])
|
ray.get([p.setup.remote() for p in procs])
|
||||||
|
@ -1,9 +1,13 @@
|
|||||||
|
import copy
|
||||||
|
import os
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import ray.util.collective as cc
|
import ray.util.collective as cc
|
||||||
import torch
|
import torch
|
||||||
|
import tqdm
|
||||||
from coati.dataset.loader import RawConversationDataset
|
from coati.dataset.loader import RawConversationDataset
|
||||||
|
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
@ -11,7 +15,12 @@ from colossalai.utils import get_current_device
|
|||||||
|
|
||||||
from .comm import ray_broadcast_tensor_dict
|
from .comm import ray_broadcast_tensor_dict
|
||||||
from .inference_backend import BACKEND_MAP
|
from .inference_backend import BACKEND_MAP
|
||||||
from .utils import pre_send
|
from .utils import pre_send, safe_write_jsonl
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm import SamplingParams
|
||||||
|
except ImportError:
|
||||||
|
LLM = None
|
||||||
|
|
||||||
|
|
||||||
class BaseProducer:
|
class BaseProducer:
|
||||||
@ -22,7 +31,7 @@ class BaseProducer:
|
|||||||
num_consumer_procs: int,
|
num_consumer_procs: int,
|
||||||
num_episodes: int,
|
num_episodes: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
dataset_config: Dict[str, Any],
|
train_dataset_config: Dict[str, Any],
|
||||||
dataloaders_config: Dict[str, Any],
|
dataloaders_config: Dict[str, Any],
|
||||||
model_config: Dict[str, Any],
|
model_config: Dict[str, Any],
|
||||||
generate_config: Dict[str, Any],
|
generate_config: Dict[str, Any],
|
||||||
@ -30,6 +39,10 @@ class BaseProducer:
|
|||||||
microbatch_size: int = 1,
|
microbatch_size: int = 1,
|
||||||
backend: str = "transformers",
|
backend: str = "transformers",
|
||||||
consumer_plugin_config: Dict[str, Any] = None,
|
consumer_plugin_config: Dict[str, Any] = None,
|
||||||
|
eval_dataset_config=None,
|
||||||
|
eval_interval=-1, # disable evaluation
|
||||||
|
evaluation_function_type="think_answer_tags",
|
||||||
|
eval_save_dir: str = "./eval",
|
||||||
):
|
):
|
||||||
self.producer_idx = producer_idx
|
self.producer_idx = producer_idx
|
||||||
self.num_producers = num_producers
|
self.num_producers = num_producers
|
||||||
@ -40,10 +53,17 @@ class BaseProducer:
|
|||||||
assert batch_size % microbatch_size == 0
|
assert batch_size % microbatch_size == 0
|
||||||
self.num_microbatches = batch_size // microbatch_size
|
self.num_microbatches = batch_size // microbatch_size
|
||||||
|
|
||||||
self.dataset_config = dataset_config
|
self.train_dataset_config = train_dataset_config
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.generate_config = generate_config
|
self.generate_config = generate_config
|
||||||
self.tokenizer_config = tokenizer_config
|
self.tokenizer_config = tokenizer_config
|
||||||
|
self.consumer_plugin_config = consumer_plugin_config
|
||||||
|
self.eval_interval = eval_interval
|
||||||
|
self.eval_save_dir = eval_save_dir
|
||||||
|
self.consumer_global_step = 0
|
||||||
|
|
||||||
|
if os.path.exists(self.eval_save_dir):
|
||||||
|
raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
|
||||||
|
|
||||||
# init tokenizer
|
# init tokenizer
|
||||||
if tokenizer_config is None:
|
if tokenizer_config is None:
|
||||||
@ -55,13 +75,13 @@ class BaseProducer:
|
|||||||
self.tokenizer.padding_side = "left"
|
self.tokenizer.padding_side = "left"
|
||||||
|
|
||||||
# init dataloader
|
# init dataloader
|
||||||
dataset_path = dataset_config.pop("path")
|
train_dataset_path = train_dataset_config.pop("path")
|
||||||
self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
|
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
|
||||||
self.dataloader = DataLoader(
|
self.train_dataloader = DataLoader(
|
||||||
self.dataset,
|
self.train_dataset,
|
||||||
batch_size=microbatch_size,
|
batch_size=microbatch_size,
|
||||||
sampler=DistributedSampler(
|
sampler=DistributedSampler(
|
||||||
self.dataset,
|
self.train_dataset,
|
||||||
num_replicas=num_producers,
|
num_replicas=num_producers,
|
||||||
rank=producer_idx,
|
rank=producer_idx,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
@ -71,6 +91,36 @@ class BaseProducer:
|
|||||||
num_workers=4,
|
num_workers=4,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.eval_dataset_config = eval_dataset_config
|
||||||
|
if self.eval_dataset_config is not None:
|
||||||
|
self.eval_dataloaders = {}
|
||||||
|
for eval_task_name in self.eval_dataset_config:
|
||||||
|
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
|
||||||
|
eval_dataset = RawConversationDataset(
|
||||||
|
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
|
||||||
|
)
|
||||||
|
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
|
||||||
|
self.eval_dataloaders[eval_task_name] = DataLoader(
|
||||||
|
eval_dataset,
|
||||||
|
batch_size=microbatch_size,
|
||||||
|
sampler=DistributedSampler(
|
||||||
|
eval_dataset,
|
||||||
|
num_replicas=num_producers,
|
||||||
|
rank=producer_idx,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False,
|
||||||
|
seed=42,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if evaluation_function_type == "think_answer_tags":
|
||||||
|
self.evaluation_function = math_reward_fn
|
||||||
|
elif evaluation_function_type == "boxed":
|
||||||
|
self.evaluation_function = boxed_math_reward_fn
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
|
||||||
|
else:
|
||||||
|
raise ValueError("eval_dataset_config is not defined")
|
||||||
self.device = get_current_device()
|
self.device = get_current_device()
|
||||||
|
|
||||||
# init backend
|
# init backend
|
||||||
@ -79,7 +129,7 @@ class BaseProducer:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected backend {backend}")
|
raise ValueError(f"Unexpected backend {backend}")
|
||||||
|
|
||||||
self.consumer_pp_size = consumer_plugin_config["pp_size"] # consumer pp size
|
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1) # consumer pp size
|
||||||
|
|
||||||
def setup(self) -> None:
|
def setup(self) -> None:
|
||||||
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
|
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
|
||||||
@ -96,29 +146,67 @@ class BaseProducer:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def loop(self) -> None:
|
def loop(self) -> None:
|
||||||
num_update_per_episode = len(self.dataloader) // self.num_microbatches
|
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
|
||||||
num_valid_microbatches = num_update_per_episode * self.num_microbatches
|
num_valid_microbatches = num_update_per_episode * self.num_microbatches
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
|
f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.train_dataloader)}"
|
||||||
)
|
)
|
||||||
for episode in range(self.num_episodes):
|
for episode in range(self.num_episodes):
|
||||||
self.dataloader.sampler.set_epoch(episode)
|
self.train_dataloader.sampler.set_epoch(episode)
|
||||||
for i, batch in enumerate(self.dataloader):
|
for i, batch in enumerate(self.train_dataloader):
|
||||||
if i >= num_valid_microbatches:
|
if i >= num_valid_microbatches:
|
||||||
break
|
break
|
||||||
|
if self.eval_interval > 0 and self.eval_dataset_config is not None:
|
||||||
|
if i % self.eval_interval == 0:
|
||||||
|
eval_statistics = {}
|
||||||
|
for eval_task_name in self.eval_dataloaders:
|
||||||
|
print(
|
||||||
|
f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
|
||||||
|
)
|
||||||
|
eval_results = []
|
||||||
|
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
|
||||||
|
for eval_batch in tqdm.tqdm(
|
||||||
|
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
|
||||||
|
):
|
||||||
|
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
|
||||||
|
eval_results = eval_results + [
|
||||||
|
self.evaluation_function(
|
||||||
|
eval_outputs["input_ids"][m][n],
|
||||||
|
eval_outputs["gt_answer"][m][n],
|
||||||
|
eval_outputs["response_idx"][m][n],
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
eval_mode=True,
|
||||||
|
)
|
||||||
|
for m in range(eval_outputs["input_ids"].size(0))
|
||||||
|
for n in range(eval_outputs["input_ids"].size(1))
|
||||||
|
]
|
||||||
|
eval_statistics[eval_task_name][0] += len(
|
||||||
|
[res for res in eval_results if res["ans_valid"] == 1]
|
||||||
|
)
|
||||||
|
eval_statistics[eval_task_name][1] += len(eval_results)
|
||||||
|
# save eval results
|
||||||
|
result_file_name = os.path.join(
|
||||||
|
self.eval_save_dir,
|
||||||
|
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
|
||||||
|
)
|
||||||
|
# delete the file if it exists
|
||||||
|
safe_write_jsonl(result_file_name, eval_results)
|
||||||
|
print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
|
||||||
|
eval_statistics["consumer_global_step"] = torch.tensor(
|
||||||
|
[self.consumer_global_step], device=self.device
|
||||||
|
)
|
||||||
|
ray_broadcast_tensor_dict(
|
||||||
|
eval_statistics,
|
||||||
|
src=0,
|
||||||
|
device=self.device,
|
||||||
|
group_name=f"sync_data_{self.producer_idx}",
|
||||||
|
)
|
||||||
outputs = self.rollout(**batch)
|
outputs = self.rollout(**batch)
|
||||||
|
|
||||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||||
outputs["temperature"] = torch.tensor(
|
outputs["temperature"] = torch.tensor(
|
||||||
[
|
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
|
||||||
(
|
|
||||||
self.model.generate_config["temperature"]
|
|
||||||
if isinstance(self.model.generate_config.temperature, dict)
|
|
||||||
else self.model.generate_config.temperature
|
|
||||||
)
|
|
||||||
]
|
|
||||||
* outputs["input_ids"].size(0)
|
|
||||||
).to(outputs["input_ids"].device)
|
).to(outputs["input_ids"].device)
|
||||||
outputs = pre_send(outputs)
|
outputs = pre_send(outputs)
|
||||||
ray_broadcast_tensor_dict(
|
ray_broadcast_tensor_dict(
|
||||||
@ -150,6 +238,8 @@ class BaseProducer:
|
|||||||
state_dict = ray_broadcast_tensor_dict(
|
state_dict = ray_broadcast_tensor_dict(
|
||||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||||
)
|
)
|
||||||
|
if "consumer_global_step" in state_dict:
|
||||||
|
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict)
|
||||||
del state_dict
|
del state_dict
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -159,15 +249,12 @@ class BaseProducer:
|
|||||||
self.model.llm.wake_up()
|
self.model.llm.wake_up()
|
||||||
# linear annealing for 1 episode, temperature from initial to 0.9
|
# linear annealing for 1 episode, temperature from initial to 0.9
|
||||||
if episode <= 0:
|
if episode <= 0:
|
||||||
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
|
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
||||||
if isinstance(self.model.generate_config.temperature, dict):
|
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||||
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
"temperature"
|
||||||
"temperature"
|
] + ratio * 0.9
|
||||||
] + ratio * 0.9
|
if hasattr(self.model, "sample_params"):
|
||||||
else:
|
self.model.sample_params.temperature = self.model.generate_config["temperature"]
|
||||||
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
|
|
||||||
"temperature"
|
|
||||||
] + ratio * 0.9
|
|
||||||
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
@ -179,7 +266,7 @@ class SimpleProducer(BaseProducer):
|
|||||||
num_consumer_procs,
|
num_consumer_procs,
|
||||||
num_episodes,
|
num_episodes,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_config,
|
train_dataset_config,
|
||||||
dataloaders_config,
|
dataloaders_config,
|
||||||
model_config,
|
model_config,
|
||||||
generate_config,
|
generate_config,
|
||||||
@ -188,6 +275,10 @@ class SimpleProducer(BaseProducer):
|
|||||||
backend="transformers",
|
backend="transformers",
|
||||||
num_generations: int = 8,
|
num_generations: int = 8,
|
||||||
consumer_plugin_config=None,
|
consumer_plugin_config=None,
|
||||||
|
eval_dataset_config=None,
|
||||||
|
eval_interval=-1, # disable evaluation
|
||||||
|
evaluation_function_type="think_answer_tags",
|
||||||
|
eval_save_dir: str = "./eval",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
producer_idx,
|
producer_idx,
|
||||||
@ -195,7 +286,7 @@ class SimpleProducer(BaseProducer):
|
|||||||
num_consumer_procs,
|
num_consumer_procs,
|
||||||
num_episodes,
|
num_episodes,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_config,
|
train_dataset_config,
|
||||||
dataloaders_config,
|
dataloaders_config,
|
||||||
model_config,
|
model_config,
|
||||||
generate_config,
|
generate_config,
|
||||||
@ -203,8 +294,15 @@ class SimpleProducer(BaseProducer):
|
|||||||
microbatch_size,
|
microbatch_size,
|
||||||
backend,
|
backend,
|
||||||
consumer_plugin_config,
|
consumer_plugin_config,
|
||||||
|
eval_dataset_config=eval_dataset_config,
|
||||||
|
eval_interval=eval_interval,
|
||||||
|
evaluation_function_type=evaluation_function_type,
|
||||||
|
eval_save_dir=eval_save_dir,
|
||||||
)
|
)
|
||||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
|
||||||
|
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||||
|
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||||
|
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||||
|
@ -1,10 +1,74 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from latex2sympy2_extended import NormalizationConfig
|
||||||
|
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
|
||||||
|
|
||||||
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
|
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
|
||||||
|
|
||||||
|
CANNOT_PARSE_GT_ANSWER = -1
|
||||||
|
CANNOT_PARSE_PREDICTION = -2
|
||||||
|
SUCCESS = 1
|
||||||
|
MATCHING_FAIL = 0
|
||||||
|
|
||||||
|
|
||||||
|
def verify_math_representation(completion, gt_answer):
|
||||||
|
"""
|
||||||
|
Verify if the completion is a valid math representation of the gt_answer.
|
||||||
|
"""
|
||||||
|
target = (
|
||||||
|
ExprExtractionConfig(),
|
||||||
|
LatexExtractionConfig(
|
||||||
|
normalization_config=NormalizationConfig(
|
||||||
|
nits=False,
|
||||||
|
malformed_operators=False,
|
||||||
|
basic_latex=True,
|
||||||
|
boxed="all",
|
||||||
|
units=True,
|
||||||
|
),
|
||||||
|
boxed_match_priority=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if not isinstance(gt_answer, str) or len(gt_answer) == 0:
|
||||||
|
raise ValueError("gt_answer should be a string, please verify your training data.")
|
||||||
|
if not isinstance(completion, str) or len(completion) == 0:
|
||||||
|
return MATCHING_FAIL
|
||||||
|
try:
|
||||||
|
parsed_gt_answer = parse(gt_answer, extraction_config=target)
|
||||||
|
if len(parsed_gt_answer) == 0:
|
||||||
|
return CANNOT_PARSE_GT_ANSWER
|
||||||
|
parsed_completion = parse(completion, extraction_config=target)
|
||||||
|
if len(parsed_completion) == 0:
|
||||||
|
return CANNOT_PARSE_PREDICTION
|
||||||
|
if verify(parsed_gt_answer, parsed_completion):
|
||||||
|
return SUCCESS
|
||||||
|
else:
|
||||||
|
return MATCHING_FAIL
|
||||||
|
except Exception:
|
||||||
|
return MATCHING_FAIL
|
||||||
|
|
||||||
|
|
||||||
|
def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
|
||||||
|
math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
|
||||||
|
if math_verify_result == SUCCESS:
|
||||||
|
ans_acc += 1
|
||||||
|
reward += acc_score
|
||||||
|
elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
|
||||||
|
if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
|
||||||
|
",", ""
|
||||||
|
) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
|
||||||
|
ans_acc += 1
|
||||||
|
if math_verify_result == CANNOT_PARSE_GT_ANSWER:
|
||||||
|
# plain text answer cannot be parsed, but is correct
|
||||||
|
reward += acc_score
|
||||||
|
else:
|
||||||
|
reward += (
|
||||||
|
acc_score / 2
|
||||||
|
) # not a valid latex math representation, but the answer is correct, receive half of the score
|
||||||
|
return reward, ans_acc
|
||||||
|
|
||||||
|
|
||||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||||
tokenizer = kwargs["tokenizer"]
|
tokenizer = kwargs["tokenizer"]
|
||||||
|
eval_mode = kwargs.get("eval_mode", False)
|
||||||
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
|
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
|
||||||
acc_score = 10.0
|
acc_score = 10.0
|
||||||
reward = torch.tensor(0.0)
|
reward = torch.tensor(0.0)
|
||||||
@ -34,46 +98,28 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
format_acc += 1
|
format_acc += 1
|
||||||
|
|
||||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||||
if (
|
if format_valid and final_answer is not None:
|
||||||
format_valid
|
reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
|
||||||
and final_answer is not None
|
|
||||||
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
|
||||||
):
|
|
||||||
ans_acc += 1
|
|
||||||
reward += acc_score
|
|
||||||
|
|
||||||
reward = reward + length_reward
|
reward = reward + length_reward
|
||||||
|
|
||||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
if not eval_mode:
|
||||||
|
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||||
|
|
||||||
def gsm8k_reward_fn(input_ids, **kwargs):
|
|
||||||
gt_answer = kwargs["gt_answer"]
|
|
||||||
tokenizer = kwargs["tokenizer"]
|
|
||||||
s, e = kwargs["response_start"], kwargs["response_end"]
|
|
||||||
reward = torch.tensor(0.0).to(input_ids.device)
|
|
||||||
if gt_answer is None:
|
|
||||||
return reward
|
|
||||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
|
||||||
final_answer, processed_str = extract_solution(decoded_final_answer)
|
|
||||||
is_valid = True
|
|
||||||
try:
|
|
||||||
int(final_answer.strip())
|
|
||||||
except Exception:
|
|
||||||
is_valid = False
|
|
||||||
|
|
||||||
format_valid = validate_response_structure(processed_str, kwargs["tags"])
|
|
||||||
if not is_valid or not format_valid:
|
|
||||||
return reward
|
|
||||||
else:
|
else:
|
||||||
reward += 1.0
|
prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
|
||||||
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
|
return {
|
||||||
reward = reward + 9.0
|
"prompt": prompt,
|
||||||
return reward
|
"prediction": decoded_final_answer,
|
||||||
|
"gold": gt_answer,
|
||||||
|
"parsed": final_answer,
|
||||||
|
"format_valid": format_acc.item(),
|
||||||
|
"ans_valid": ans_acc.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||||
tokenizer = kwargs["tokenizer"]
|
tokenizer = kwargs["tokenizer"]
|
||||||
|
eval_mode = kwargs.get("eval_mode", False)
|
||||||
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
|
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
|
||||||
format_score = 0.0
|
format_score = 0.0
|
||||||
acc_score = 10.0
|
acc_score = 10.0
|
||||||
@ -91,7 +137,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
|
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
|
||||||
|
|
||||||
if gt_answer is None:
|
if gt_answer is None:
|
||||||
return reward
|
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||||
|
|
||||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||||
@ -103,10 +149,19 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
|||||||
reward += format_score
|
reward += format_score
|
||||||
|
|
||||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||||
if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
|
if format_valid and final_answer is not None:
|
||||||
ans_acc += 1
|
reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
|
||||||
reward += acc_score
|
|
||||||
|
|
||||||
reward = reward + length_reward
|
reward = reward + length_reward
|
||||||
|
if not eval_mode:
|
||||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||||
|
else:
|
||||||
|
prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
|
||||||
|
return {
|
||||||
|
"prompt": prompt,
|
||||||
|
"prediction": decoded_final_answer,
|
||||||
|
"gold": gt_answer,
|
||||||
|
"parsed": final_answer,
|
||||||
|
"format_valid": format_acc.item(),
|
||||||
|
"ans_valid": ans_acc.item(),
|
||||||
|
}
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from filelock import FileLock
|
||||||
|
|
||||||
from colossalai.shardformer.layer.loss import dist_log_prob
|
from colossalai.shardformer.layer.loss import dist_log_prob
|
||||||
|
|
||||||
@ -152,3 +155,13 @@ def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.
|
|||||||
"""
|
"""
|
||||||
tensor = tensor * mask
|
tensor = tensor * mask
|
||||||
return tensor.sum(dim=dim)
|
return tensor.sum(dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_write_jsonl(file_path, data):
|
||||||
|
with FileLock(file_path + ".lock"):
|
||||||
|
# Ensure file exists
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
with open(file_path, "a", encoding="utf8") as f:
|
||||||
|
for entry in data:
|
||||||
|
json_line = json.dumps(entry, ensure_ascii=False)
|
||||||
|
f.write(json_line + "\n")
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
@ -8,7 +9,16 @@ from coati.distributed.launch import launch_distributed
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
parser.add_argument("-d", "--dataset", type=str, default="data_train.jsonl")
|
||||||
|
parser.add_argument(
|
||||||
|
"-ed",
|
||||||
|
"--eval-dataset",
|
||||||
|
type=str,
|
||||||
|
default='{"eval task name":"data_eval.jsonl"}',
|
||||||
|
help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
|
||||||
|
For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
|
||||||
|
The key is the task name, and the value is the path to the jsonl file",
|
||||||
|
)
|
||||||
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||||
parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
|
parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
|
||||||
|
|
||||||
@ -94,11 +104,14 @@ if __name__ == "__main__":
|
|||||||
choices=["think_answer_tags", "boxed"],
|
choices=["think_answer_tags", "boxed"],
|
||||||
help="Reward type for GRPO.",
|
help="Reward type for GRPO.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
|
||||||
|
|
||||||
# Logging/Checkpointing parameters
|
# Logging/Checkpointing parameters
|
||||||
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||||
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
|
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
|
||||||
|
parser.add_argument(
|
||||||
|
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.train_minibatch_size is None:
|
if args.train_minibatch_size is None:
|
||||||
@ -208,7 +221,7 @@ if __name__ == "__main__":
|
|||||||
inference_microbatch_size=args.inference_microbatch_size,
|
inference_microbatch_size=args.inference_microbatch_size,
|
||||||
train_batch_size=args.train_batch_size,
|
train_batch_size=args.train_batch_size,
|
||||||
train_minibatch_size=args.train_minibatch_size,
|
train_minibatch_size=args.train_minibatch_size,
|
||||||
dataset_config={
|
train_dataset_config={
|
||||||
"path": args.dataset,
|
"path": args.dataset,
|
||||||
"max_length": args.max_prompt_tokens,
|
"max_length": args.max_prompt_tokens,
|
||||||
"system_prompt": args.system_prompt,
|
"system_prompt": args.system_prompt,
|
||||||
@ -238,4 +251,10 @@ if __name__ == "__main__":
|
|||||||
project_name=args.project,
|
project_name=args.project,
|
||||||
save_interval=args.save_interval,
|
save_interval=args.save_interval,
|
||||||
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
|
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
|
||||||
|
eval_dataset_config={
|
||||||
|
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
|
||||||
|
for k, v in json.loads(args.eval_dataset).items()
|
||||||
|
},
|
||||||
|
eval_interval=args.eval_interval,
|
||||||
|
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user