Mirror of https://github.com/hpcaitech/ColossalAI.git
fix scheduling for multi-node training
Commit 7d658402da (parent d06042b434)
@@ -358,9 +358,6 @@ def apply_chat_template_and_mask(
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
 
-    if system_prompt is None:
-        system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
-
     system_element = {
         "role": "system",
         "content": system_prompt,
@@ -375,7 +372,9 @@ def apply_chat_template_and_mask(
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
-        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
+        msg_tokens = tokenizer.apply_chat_template(
+            [system_element, msg] if system_prompt else [msg], tokenize=True, add_generation_prompt=True
+        )
         # remove unexpected bos token
         if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
             msg_tokens = msg_tokens[1:]
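For context on the hunk above: with the hard-coded default system prompt removed, each message is templated together with the system element only when a system prompt was actually supplied. Below is a minimal sketch (not part of the commit) of that per-message templating pattern, assuming a Hugging Face tokenizer that ships a chat template; the model name and chat content are only illustrative.

from transformers import AutoTokenizer

# Illustrative model; any chat-template-capable tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")

system_prompt = None  # or a string such as "You are a helpful assistant."
system_element = {"role": "system", "content": system_prompt}
chat = [{"role": "user", "content": "What is 2 + 2?"}]

tokens = []
for i, msg in enumerate(chat):
    # Prepend the system element only when a system prompt was provided.
    msg_tokens = tokenizer.apply_chat_template(
        [system_element, msg] if system_prompt else [msg],
        tokenize=True,
        add_generation_prompt=True,
    )
    # Drop an unexpected BOS token on all but the first message.
    if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
        msg_tokens = msg_tokens[1:]
    tokens.extend(msg_tokens)
print(len(tokens))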
@@ -115,10 +115,10 @@ class BaseConsumer:
                 eval_statistics = None
                 eval_global_step = None
                 for r in range(self.num_producers):
-                    print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
                     local_eval_result = ray_broadcast_tensor_dict(
                         None, src=0, device=self.device, group_name=f"sync_data_{r}"
                     )
+                    print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
                     assert "consumer_global_step" in local_eval_result
                     eval_global_step = local_eval_result.pop("consumer_global_step").item()
                     if eval_statistics is None:
@@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
         save_interval: int = 100,
         save_dir="./model",
         eval_interval: int = -1,
+        response_format_tags: dict = None,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -125,12 +126,6 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = {
-            "think_start": {"text": "<think>", "num_occur": 1},
-            "think_end": {"text": "</think>", "num_occur": 1},
-            "answer_start": {"text": "<answer>", "num_occur": 1},
-            "answer_end": {"text": "</answer>", "num_occur": 1},
-        }
         reward_model_kwargs = {
             k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
         }
@@ -377,7 +372,7 @@ class GRPOConsumer(BaseConsumer):
                     kl.append(appox_kl.mean())
                 else:
                     per_token_kl = 0.0
-                    kl.append(0.0)
+                    kl.append(torch.zeros(1, device=action_log_probs.device))
 
                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
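The commit does not spell out why the hunk above appends a zero tensor instead of a Python float; a plausible reason is that downstream aggregation (for example torch.stack or a collective) expects every entry of the kl list to be a tensor on the same device. A minimal sketch of that failure mode, under that assumption:

import torch

kl = [torch.tensor([0.12]), 0.0]  # mixing a tensor and a plain float
try:
    torch.stack(kl)
except TypeError as e:
    print("mixed list fails:", e)

kl = [torch.tensor([0.12]), torch.zeros(1)]  # all tensors, as in the updated code
print(torch.stack(kl).mean())  # aggregates cleanly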
@@ -53,6 +53,7 @@ def launch_distributed(
     eval_dataset_config: Optional[Dict[str, Any]] = None,
     eval_interval: int = 100,
     eval_save_dir: Optional[str] = None,
+    response_format_tags: Dict[Any] = None,
 ):
 
     if core_algo not in ALGO_MAP:
@@ -65,13 +66,46 @@ def launch_distributed(
 
     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
+    global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
 
-    procs = []
+    # Attention: Ray uses a complex scheduling method that considers various factors, including load balancing.
+    # When requesting resources, it is not guaranteed that the resources come from a node with a lower node id;
+    # this goes against the design principle of our implementation, so we need to manually force the scheduling,
+    # allocating the producers to nodes with lower node ids and the consumers to the resources of nodes with
+    # higher node ids. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
+    nodes = ray.nodes()
+    node_info = {
+        node["NodeID"]: {
+            "num_gpus": node["Resources"].get("GPU", 0),
+            "address": node["NodeManagerAddress"],
+        }  # Default to 0 if no GPUs are available
+        for node in nodes
+    }
+    gpu_to_node_id = []
+    gpu_to_ip_address = []
+    for node_id in node_info:
+        for idx in range(int(node_info[node_id]["num_gpus"])):
+            gpu_to_node_id.append(node_id)
+            gpu_to_ip_address.append(node_info[node_id]["address"])
+    print(node_info)
+
+    producer_procs = []
     for i in range(num_producers):
-        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+        node_id = gpu_to_node_id[0]
+        producer_ip_address = gpu_to_ip_address[0]
+        for _ in range(num_proc_per_producer):
+            gpu_to_node_id.pop(0)
+            gpu_to_ip_address.pop(0)
+        print(f"Schedule Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
+        producer = SimpleProducer.options(
+            num_gpus=num_proc_per_producer,
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             producer_idx=i,
             num_producers=num_producers,
             num_consumer_procs=num_consumer_procs,
@@ -91,20 +125,35 @@ def launch_distributed(
             evaluation_function_type=grpo_config["reward_fn_type"],
             eval_save_dir=eval_save_dir,
         )
-        procs.append(producer)
+        producer_procs.append(producer)
+    ray.get([p.setup.remote() for p in producer_procs])
     generate_config_consumer = copy.deepcopy(generate_config)
     generate_config_consumer.update(
         dict(
             backend=inference_backend,
         )
     )
+    consumer_master_ip_address = gpu_to_ip_address[0]
+    print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
+    consumer_procs = []
     for i in range(num_consumer_procs):
-        consumer = core_consumer.options(num_gpus=1).remote(
+        node_id = gpu_to_node_id[0]
+        consumer_ip_address = gpu_to_ip_address[0]
+        gpu_to_node_id.pop(0)
+        gpu_to_ip_address.pop(0)
+        print(f"Schedule Consumer T[{i}] which requires 1 GPU on node {consumer_ip_address}")
+        consumer = core_consumer.options(
+            num_gpus=1,
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             num_producers=num_producers,
             num_episodes=num_episodes,
             rank=i,
             world_size=num_consumer_procs,
-            master_addr=master_addr,
+            master_addr=consumer_master_ip_address,
             master_port=master_port,
             num_update_per_episode=num_update_per_episode,
             num_recv_per_update=num_recv_per_update,
@@ -119,7 +168,8 @@ def launch_distributed(
             save_interval=save_interval,
             save_dir=save_dir,
             eval_interval=eval_interval,
+            response_format_tags=response_format_tags,
         )
-        procs.append(consumer)
-    ray.get([p.setup.remote() for p in procs])
-    ray.get([p.loop.remote() for p in procs])
+        consumer_procs.append(consumer)
+    ray.get([p.setup.remote() for p in consumer_procs])
+    ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])
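The launch_distributed hunks above pin producers and consumers to specific nodes with Ray's NodeAffinitySchedulingStrategy instead of letting the default scheduler spread actors by load balancing: producer actors consume GPUs from the front of the node-ordered GPU list, so consumers naturally end up on the remaining, higher-id nodes. Below is a minimal, self-contained sketch of the same pattern (not from the commit; the actor and its resource request are illustrative).

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()

# Map each GPU in the cluster to the node that owns it, in node order.
gpu_to_node_id = []
for node in ray.nodes():
    if not node["Alive"]:
        continue
    for _ in range(int(node["Resources"].get("GPU", 0))):
        gpu_to_node_id.append(node["NodeID"])


@ray.remote(num_gpus=1)
class Worker:
    def where_am_i(self):
        # Node id of the node this actor actually landed on.
        return ray.get_runtime_context().get_node_id()


# Pin the actor to the first node that has a GPU; soft=False makes the
# placement a hard requirement rather than a hint.
worker = Worker.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=gpu_to_node_id[0], soft=False)
).remote()
assert ray.get(worker.where_am_i.remote()) == gpu_to_node_id[0]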
@@ -112,6 +112,8 @@ class BaseProducer:
                         drop_last=False,
                         seed=42,
                     ),
+                    num_workers=4,
+                    drop_last=False,
                 )
             if evaluation_function_type == "think_answer_tags":
                 self.evaluation_function = math_reward_fn
@@ -166,8 +168,8 @@ class BaseProducer:
                 )
                 eval_results = []
                 eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
-                for eval_batch in tqdm.tqdm(
-                    self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
+                for eval_batch_id, eval_batch in tqdm.tqdm(
+                    enumerate(self.eval_dataloaders[eval_task_name]), desc=f"Evaluating: {eval_task_name}"
                 ):
                     eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
                     eval_results = eval_results + [
@@ -190,9 +192,7 @@ class BaseProducer:
                     self.eval_save_dir,
                     f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
                 )
-                # delete the file if it exists
                 safe_write_jsonl(result_file_name, eval_results)
-                print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
                 eval_statistics["consumer_global_step"] = torch.tensor(
                     [self.consumer_global_step], device=self.device
                 )
@@ -230,6 +230,8 @@ class BaseProducer:
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                     )
+                    if "consumer_global_step" in state_dict:
+                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
                     self.load_state_dict(state_dict)
             else:
                 print(
@@ -308,7 +310,10 @@ class SimpleProducer(BaseProducer):
     def rollout(self, input_ids, attention_mask, **kwargs):
         rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
         if self.producer_idx == 1:
-            print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
+            print(
+                "Truncated rollout example:\n",
+                self.tokenizer.decode(rollouts["input_ids"][0][0][-5000:], skip_special_tokens=True),
+            )
 
         return rollouts
 
@@ -14,6 +14,10 @@ def verify_math_representation(completion, gt_answer):
     """
     Verify if the completion is a valid math representation of the gt_answer.
     """
+    if not completion.startswith("\\boxed{"):
+        completion = "\\boxed{" + completion + "}"
+    if not gt_answer.startswith("\\boxed{"):
+        gt_answer = "\\boxed{" + gt_answer + "}"
     target = (
         ExprExtractionConfig(),
         LatexExtractionConfig(
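The hunk above normalises both the prediction and the ground truth to a \boxed{...} form before parsing, so the LaTeX extractor has a recognisable anchor. A minimal sketch (not from the commit) of that idea, assuming the math_verify package whose extraction configs appear in the surrounding code; the concrete answer strings are illustrative:

from math_verify import parse, verify

gt_answer = "\\boxed{12}"
completion = "12"  # model answered without the \boxed{} wrapper

# Mirror the normalisation added in the hunk above.
if not completion.startswith("\\boxed{"):
    completion = "\\boxed{" + completion + "}"
if not gt_answer.startswith("\\boxed{"):
    gt_answer = "\\boxed{" + gt_answer + "}"

gold = parse(gt_answer)
pred = parse(completion)
print(verify(gold, pred))  # expected: True, both parse to the value 12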
@@ -59,7 +63,7 @@ def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, rew
         if math_verify_result == CANNOT_PARSE_GT_ANSWER:
             # plain text answer cannot be parsed, but is correct
             reward += acc_score
-        else:
+        elif math_verify_result == CANNOT_PARSE_PREDICTION:
             reward += (
                 acc_score / 2
             )  # not a valid latex math representation, but the answer is correct, receive half of the score
@@ -140,9 +144,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
 
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
     final_answer = extract_boxed_solution(decoded_final_answer)
     format_valid = final_answer is not None
+    if "tags" in kwargs:
+        tags = kwargs["tags"]
+        format_valid = format_valid and all(
+            [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
+        )
     # Check format accuracy
     if format_valid:
         format_acc += 1
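The format check added above simply counts tag occurrences in the decoded completion and compares them against the configured counts. A small illustration using the same tag schema (the completion strings are made up):

tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}

good = "<think> 6 * 2 = 12 </think><answer> \\boxed{12} </answer>"
bad = "<think> 6 * 2 = 12 <answer> \\boxed{12} </answer>"  # missing </think>


def tags_ok(text: str) -> bool:
    # Every tag must appear exactly the configured number of times.
    return all(text.count(tags[t]["text"]) == tags[t]["num_occur"] for t in tags)


print(tags_ok(good))  # True
print(tags_ok(bad))   # False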
@@ -6,6 +6,13 @@ import ray
 import torch
 from coati.distributed.launch import launch_distributed
 
+DEFAULT_RESPONSE_FORMAT_TAGS = {
+    "think_start": {"text": "<think>", "num_occur": 1},
+    "think_end": {"text": "</think>", "num_occur": 1},
+    "answer_start": {"text": "<answer>", "num_occur": 1},
+    "answer_end": {"text": "</answer>", "num_occur": 1},
+}
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -65,10 +72,22 @@ if __name__ == "__main__":
         "--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
     )
     parser.add_argument(
-        "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
+        "--master_address",
+        type=str,
+        default=None,
+        help="Ray master address for multi-node distributed training, Optional",
     )
     parser.add_argument(
-        "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
+        "--torch_ddp_master_address",
+        type=str,
+        default=None,
+        help="Torch DDP master address for multi-node distributed training, Optional",
+    )
+    parser.add_argument(
+        "--torch_ddp_master_port",
+        type=int,
+        default=29505,
+        help="Torch DDP master port for multi-node distributed training, Optional",
     )
 
     # Sampling parameters
@@ -105,6 +124,9 @@ if __name__ == "__main__":
         help="Reward type for GRPO.",
     )
     parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+    parser.add_argument(
+        "-rft", "--response_format_tags", type=str, default=None, help="Optional JSON string of the response format tags"
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -236,17 +258,17 @@ if __name__ == "__main__":
             "zero_stage": 2,
         },  # for zero
         # plugin_config={
-        #   "tp_size": 1,
+        #   "tp_size": 4,
         #   "pp_size": 2,
         #   "microbatch_size": max(
         #       1, args.train_microbatch_size // 2
         #   ),  # microbatch size should be set to train_microbatch_size // pp_size
-        #   "zero_stage": 0,
+        #   "zero_stage": 1,
         #   "max_norm": 1.0,
         # },  # for pp, tp
         inference_backend=args.backend,
-        master_addr="localhost",
-        master_port=args.master_port,
+        master_addr=args.torch_ddp_master_address,
+        master_port=args.torch_ddp_master_port,
         core_algo=args.algo,
         project_name=args.project,
         save_interval=args.save_interval,
@@ -257,4 +279,9 @@ if __name__ == "__main__":
         },
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
+        response_format_tags=(
+            json.loads(args.response_format_tags)
+            if args.response_format_tags is not None
+            else DEFAULT_RESPONSE_FORMAT_TAGS
+        ),
     )
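The new --response_format_tags flag takes the tag schema as a JSON string and falls back to DEFAULT_RESPONSE_FORMAT_TAGS when omitted. A short sketch of how such a value round-trips through json.loads (the value shown is only an example):

import json

# Example value that could be passed on the command line via --response_format_tags.
cli_value = json.dumps(
    {
        "think_start": {"text": "<think>", "num_occur": 1},
        "think_end": {"text": "</think>", "num_occur": 1},
    }
)
response_format_tags = json.loads(cli_value)
print(response_format_tags["think_start"]["text"])  # <think>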