diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index 43cf78383..3393e38d2 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -358,9 +358,6 @@ def apply_chat_template_and_mask(
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
- if system_prompt is None:
- system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
-
system_element = {
"role": "system",
"content": system_prompt,
@@ -375,7 +372,9 @@ def apply_chat_template_and_mask(
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
- msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
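+ # Prepend the system message only when a system prompt is provided.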
+ msg_tokens = tokenizer.apply_chat_template(
+ [system_element, msg] if system_prompt else [msg], tokenize=True, add_generation_prompt=True
+ )
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index a3a0948bf..54ca611de 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -115,10 +115,10 @@ class BaseConsumer:
eval_statistics = None
eval_global_step = None
for r in range(self.num_producers):
- print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
local_eval_result = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
+ print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
assert "consumer_global_step" in local_eval_result
eval_global_step = local_eval_result.pop("consumer_global_step").item()
if eval_statistics is None:
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 0336125d1..b4175fc26 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
save_interval: int = 100,
save_dir="./model",
eval_interval: int = -1,
+ response_format_tags: dict = None,
):
print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -125,12 +126,6 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
- response_format_tags = {
- "think_start": {"text": "<think>", "num_occur": 1},
- "think_end": {"text": "</think>", "num_occur": 1},
- "answer_start": {"text": "<answer>", "num_occur": 1},
- "answer_end": {"text": "</answer>", "num_occur": 1},
- }
reward_model_kwargs = {
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
}
@@ -377,7 +372,7 @@ class GRPOConsumer(BaseConsumer):
kl.append(appox_kl.mean())
else:
per_token_kl = 0.0
- kl.append(0.0)
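+ # Append a zero tensor (rather than a Python float) so every entry in kl is a tensor on the same device.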
+ kl.append(torch.zeros(1, device=action_log_probs.device))
loss, _ = self.policy_loss_fn(
action_log_probs,
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index 400a62928..3b2f4baa9 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -53,6 +53,7 @@ def launch_distributed(
eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
+ response_format_tags: Optional[Dict[str, Any]] = None,
):
if core_algo not in ALGO_MAP:
@@ -65,13 +66,46 @@ def launch_distributed(
dataset_path = train_dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path)
- global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
+ global_inference_batch_size = inference_batch_size * num_producers
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size
- procs = []
+ # Attention: Ray uses a complex scheduling method that considers various factors including load balancing.
+ # When requesting resources, it is not guaranteed that the resources come from a node with a lower node id;
+ # this goes against the design principle of our implementation, so we need to manually force the scheduling,
+ # allocating the producers to nodes with lower node ids and the consumers to the resources from nodes with higher
+ # node ids. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
+ nodes = ray.nodes()
+ node_info = {
+ node["NodeID"]: {
+ "num_gpus": node["Resources"].get("GPU", 0),
+ "address": node["NodeManagerAddress"],
+ } # Default to 0 if no GPUs are available
+ for node in nodes
+ }
+ gpu_to_node_id = []
+ gpu_to_ip_address = []
+ for node_id in node_info:
+ for idx in range(int(node_info[node_id]["num_gpus"])):
+ gpu_to_node_id.append(node_id)
+ gpu_to_ip_address.append(node_info[node_id]["address"])
+ print(node_info)
+
+ producer_procs = []
for i in range(num_producers):
- producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
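+ # Claim this producer's GPUs from the front of the list so producers occupy the nodes with the lowest node ids.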
+ node_id = gpu_to_node_id[0]
+ producer_ip_address = gpu_to_ip_address[0]
+ for _ in range(num_proc_per_producer):
+ gpu_to_node_id.pop(0)
+ gpu_to_ip_address.pop(0)
+ print(f"Schedule Producer P[{i}], which requires {num_proc_per_producer} GPUs, on node {producer_ip_address}")
+ producer = SimpleProducer.options(
+ num_gpus=num_proc_per_producer,
+ scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+ node_id=node_id,
+ soft=False,
+ ),
+ ).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,
@@ -91,20 +125,35 @@ def launch_distributed(
evaluation_function_type=grpo_config["reward_fn_type"],
eval_save_dir=eval_save_dir,
)
- procs.append(producer)
+ producer_procs.append(producer)
+ ray.get([p.setup.remote() for p in producer_procs])
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(
backend=inference_backend,
)
)
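+ # After the producers have claimed their GPUs, the first remaining GPU hosts consumer rank 0, so its node IP is used as the torch DDP master address.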
+ consumer_master_ip_address = gpu_to_ip_address[0]
+ print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
+ consumer_procs = []
for i in range(num_consumer_procs):
- consumer = core_consumer.options(num_gpus=1).remote(
+ node_id = gpu_to_node_id[0]
+ consumer_ip_address = gpu_to_ip_address[0]
+ gpu_to_node_id.pop(0)
+ gpu_to_ip_address.pop(0)
+ print(f"Schedule Consumer T[{i}], which requires 1 GPU, on node {consumer_ip_address}")
+ consumer = core_consumer.options(
+ num_gpus=1,
+ scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+ node_id=node_id,
+ soft=False,
+ ),
+ ).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
world_size=num_consumer_procs,
- master_addr=master_addr,
+ master_addr=consumer_master_ip_address,
master_port=master_port,
num_update_per_episode=num_update_per_episode,
num_recv_per_update=num_recv_per_update,
@@ -119,7 +168,8 @@ def launch_distributed(
save_interval=save_interval,
save_dir=save_dir,
eval_interval=eval_interval,
+ response_format_tags=response_format_tags,
)
- procs.append(consumer)
- ray.get([p.setup.remote() for p in procs])
- ray.get([p.loop.remote() for p in procs])
+ consumer_procs.append(consumer)
+ ray.get([p.setup.remote() for p in consumer_procs])
+ ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index de4d60b89..8e7f7d240 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -112,6 +112,8 @@ class BaseProducer:
drop_last=False,
seed=42,
),
+ num_workers=4,
+ drop_last=False,
)
if evaluation_function_type == "think_answer_tags":
self.evaluation_function = math_reward_fn
@@ -166,8 +168,8 @@ class BaseProducer:
)
eval_results = []
eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
- for eval_batch in tqdm.tqdm(
- self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
+ for eval_batch_id, eval_batch in tqdm.tqdm(
+ enumerate(self.eval_dataloaders[eval_task_name]), desc=f"Evaluating: {eval_task_name}"
):
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
eval_results = eval_results + [
@@ -190,9 +192,7 @@ class BaseProducer:
self.eval_save_dir,
f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
)
- # delete the file if it exists
safe_write_jsonl(result_file_name, eval_results)
- print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
eval_statistics["consumer_global_step"] = torch.tensor(
[self.consumer_global_step], device=self.device
)
@@ -230,6 +230,8 @@ class BaseProducer:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
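+ # The consumer piggybacks its global step onto the synced state dict; pop it before loading the model weights.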
+ if "consumer_global_step" in state_dict:
+ self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
print(
@@ -308,7 +310,10 @@ class SimpleProducer(BaseProducer):
def rollout(self, input_ids, attention_mask, **kwargs):
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 1:
- print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
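+ # Only decode and print the tail of the first rollout so that very long generations do not flood the log.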
+ print(
+ "Truncated rollout example:\n",
+ self.tokenizer.decode(rollouts["input_ids"][0][0][-5000:], skip_special_tokens=True),
+ )
return rollouts
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 467a9b414..1255124b9 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -14,6 +14,10 @@ def verify_math_representation(completion, gt_answer):
"""
Verify if the completion is a valid math representation of the gt_answer.
"""
+ if not completion.startswith("\\boxed{"):
+ completion = "\\boxed{" + completion + "}"
+ if not gt_answer.startswith("\\boxed{"):
+ gt_answer = "\\boxed{" + gt_answer + "}"
target = (
ExprExtractionConfig(),
LatexExtractionConfig(
@@ -59,7 +63,7 @@ def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, rew
if math_verify_result == CANNOT_PARSE_GT_ANSWER:
# plain text answer cannot be parsed, but is correct
reward += acc_score
- else:
+ elif math_verify_result == CANNOT_PARSE_PREDICTION:
reward += (
acc_score / 2
) # not a valid latex math representation, but the answer is correct, receive half of the score
@@ -140,9 +144,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer = extract_boxed_solution(decoded_final_answer)
format_valid = final_answer is not None
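+ # When response format tags are configured, additionally require each tag to appear exactly the expected number of times.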
+ if "tags" in kwargs:
+ tags = kwargs["tags"]
+ format_valid = format_valid and all(
+ [decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags]
+ )
# Check format accuracy
if format_valid:
format_acc += 1
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index b2fafd42e..10e21c942 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -6,6 +6,13 @@ import ray
import torch
from coati.distributed.launch import launch_distributed
+DEFAULT_RESPONSE_FORMAT_TAGS = {
+ "think_start": {"text": "<think>", "num_occur": 1},
+ "think_end": {"text": "</think>", "num_occur": 1},
+ "answer_start": {"text": "<answer>", "num_occur": 1},
+ "answer_end": {"text": "</answer>", "num_occur": 1},
+}
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -65,10 +72,22 @@ if __name__ == "__main__":
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
- "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
+ "--master_address",
+ type=str,
+ default=None,
+ help="Ray master address for multi-node distributed training, Optional",
)
parser.add_argument(
- "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
+ "--torch_ddp_master_address",
+ type=str,
+ default=None,
+ help="Torch DDP master address for multi-node distributed training, Optional",
+ )
+ parser.add_argument(
+ "--torch_ddp_master_port",
+ type=int,
+ default=29505,
+ help="Torch DDP master port for multi-node distributed training, Optional",
)
# Sampling parameters
@@ -105,6 +124,9 @@ if __name__ == "__main__":
help="Reward type for GRPO.",
)
parser.add_argument("-ei", "--eval-interval", type=int, default=100, help="Interval for evaluation.")
+ parser.add_argument(
+ "-rft", "--response-format-tags", type=str, default=None, help="Optional JSON string of the response format tags"
+ )
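+ # Example value (mirrors DEFAULT_RESPONSE_FORMAT_TAGS above): '{"think_start": {"text": "<think>", "num_occur": 1}, ...}'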
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -236,17 +258,17 @@ if __name__ == "__main__":
"zero_stage": 2,
}, # for zero
# plugin_config={
- # "tp_size": 1,
- # "pp_size": 2,
- # "microbatch_size": max(
- # 1, args.train_microbatch_size // 2
- # ), # microbatch size should be set to train_microbatch_size // pp_size
- # "zero_stage": 0,
- # "max_norm": 1.0,
+ # "tp_size": 4,
+ # "pp_size": 2,
+ # "microbatch_size": max(
+ # 1, args.train_microbatch_size // 2
+ # ), # microbatch size should be set to train_microbatch_size // pp_size
+ # "zero_stage": 1,
+ # "max_norm": 1.0,
# }, # for pp, tp
inference_backend=args.backend,
- master_addr="localhost",
- master_port=args.master_port,
+ master_addr=args.torch_ddp_master_address,
+ master_port=args.torch_ddp_master_port,
core_algo=args.algo,
project_name=args.project,
save_interval=args.save_interval,
@@ -257,4 +279,9 @@ if __name__ == "__main__":
},
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
+ response_format_tags=(
+ json.loads(args.response_format_tags)
+ if args.response_format_tags is not None
+ else DEFAULT_RESPONSE_FORMAT_TAGS
+ ),
)