add profiling, implement memory-efficient logprob calculation

This commit is contained in:
YeAnbang 2025-06-18 10:08:22 +00:00
parent ff1689b69a
commit 2db255bf15
8 changed files with 149 additions and 47 deletions

1
.gitignore vendored
View File

@@ -173,3 +173,4 @@ applications/ColossalChat/stdin
applications/ColossalChat/*.zip
applications/ColossalChat/*.prof
applications/ColossalChat/*.png
applications/ColossalChat/profiling_log/

View File

@@ -200,10 +200,15 @@ class BaseConsumer:
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
total_step += 1
self.profiler.exit("step")
self.profiler.log(
f"step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
)
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
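
For reference, the peak-memory numbers logged above come from the standard torch.cuda bookkeeping APIs. A minimal standalone sketch of the same pattern (not part of this commit; run_step is a hypothetical stand-in for self.step, and print stands in for the CustomProfiler log):

import torch

def measure_peak_memory(run_step):
    # free cached blocks so the recorded peak reflects only this step
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    out = run_step()  # e.g. one training step
    peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    print(f"peak_memory: {peak_mb:.2f}MB")
    return out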

View File

@@ -6,7 +6,7 @@ import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import calc_action_log_probs
from coati.distributed.utils import calc_action_log_probs, memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -262,7 +262,7 @@ class GRPOConsumer(BaseConsumer):
# Support training with PP.
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
torch.cuda.reset_peak_memory_stats()
# torch.cuda.reset_peak_memory_stats()
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
@@ -286,22 +286,32 @@
if self.booster.plugin.stage_manager.is_last_stage():
# breakpoint()
torch.cuda.reset_peak_memory_stats()
reference_action_log_probs = torch.zeros(
(input_ids_forward_micro_batch.size(0), num_action),
device=input_ids_forward_micro_batch.device,
# torch.cuda.empty_cache()
# current_memory = torch.cuda.memory_allocated()
# torch.cuda.reset_peak_memory_stats()
# reference_action_log_probs = calc_action_log_probs(
# reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
# input_ids_forward_micro_batch,
# num_action,
# self.plugin.shard_config,
# )
# self.profiler.log(f"reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
# torch.cuda.empty_cache()
# current_memory = torch.cuda.memory_allocated()
# torch.cuda.reset_peak_memory_stats()
reference_action_log_probs = memory_efficient_logprob(
reference_model_outputs["outputs"]["logits"],
input_ids_forward_micro_batch,
num_action,
shard_config=self.plugin.shard_config,
)
for i in range(reference_action_log_probs.size(0)):
# activation for log_softmax is too large if vocab size and sequence length are large
# e.g., when using a 152064 vocab size with a 32K sequence length and a micro batch size of 4 (for pp=4, for example),
# this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB
reference_action_log_probs[i, :] += calc_action_log_probs(
reference_model_outputs["outputs"]["logits"][i : i + 1]
/ self.generate_config["temperature"],
input_ids_forward_micro_batch[i : i + 1],
num_action,
self.plugin.shard_config,
)[0]
# self.profiler.log(f"me_reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
# if torch.allclose(reference_action_log_probs, me_reference_action_log_probs):
# self.profiler.log("Memory efficient reference action log probs is same as normal reference action log probs")
# else:
# self.profiler.log("Memory efficient reference action log probs is different from normal reference action log probs")
# breakpoint()
# torch.cuda.empty_cache()
self.profiler.log(
@@ -310,6 +320,7 @@ class GRPOConsumer(BaseConsumer):
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = None
del reference_model_outputs
else:
reference_action_log_probs = None
@@ -328,26 +339,43 @@
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = torch.zeros(
(inputs["input_ids"].size(0), num_action), device=action_logits.device
)
# breakpoint()
torch.cuda.reset_peak_memory_stats()
for i in range(action_log_probs.size(0)):
# activation for log_softmax is too large if vocab size and sequence length are large
# e.g., when using a 152064 vocab size with a 32K sequence length and a micro batch size of 4 (for pp=4, for example),
# this activation alone takes 152064*32000*4*4/1024/1024/1024 = 72.5GB
action_log_probs[i, :] += calc_action_log_probs(
action_logits[i : i + 1] / self.generate_config["temperature"],
inputs["input_ids"][i : i + 1],
num_action,
self.plugin.shard_config,
)[0]
# torch.cuda.empty_cache()
self.profiler.log(
f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
# current_memory = torch.cuda.memory_allocated()
# torch.cuda.reset_peak_memory_stats()
# action_log_probs = calc_action_log_probs(
# action_logits / self.generate_config["temperature"],
# inputs["input_ids"],
# num_action,
# self.plugin.shard_config,
# )
# # torch.cuda.empty_cache()
# self.profiler.log(
# f"action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
# )
# torch.cuda.empty_cache()
# current_memory = torch.cuda.memory_allocated()
# torch.cuda.reset_peak_memory_stats()
action_log_probs = memory_efficient_logprob(
action_logits,
inputs["input_ids"],
num_action,
shard_config=self.plugin.shard_config,
)
# self.profiler.log(
# f"me_action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
# )
# if torch.allclose(action_log_probs, me_action_log_probs):
# self.profiler.log("Memory efficient action log probs is same as normal action log probs")
# else:
# self.profiler.log("Memory efficient action log probs is different from normal action log probs")
# torch.cuda.empty_cache()
# breakpoint()
# current_memory = torch.cuda.memory_allocated()
# torch.cuda.empty_cache()
# self.profiler.log(
# f"released by del outputs: {(torch.cuda.memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
# )
if "reference_action_log_probs" in inputs:
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -463,7 +491,7 @@ class GRPOConsumer(BaseConsumer):
if need_update:
# breakpoint()
# torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# torch.cuda.reset_peak_memory_stats()
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1

View File

@@ -448,7 +448,9 @@ class DummyProducer(BaseProducer):
"input_ids": torch.cat(
[
torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
torch.ones(
torch.randint(
0,
self.tokenizer.vocab_size, # assuming vocab_size is available in tokenizer
(input_ids.size(0), self.num_generations, num_new_tokens),
dtype=input_ids.dtype,
).to(device),

View File

@@ -146,6 +146,45 @@ def safe_append_to_jsonl_file(file_path, data):
f.write(json_line + "\n")
def memory_efficient_logprob(
logits: torch.Tensor,
inputs: torch.Tensor,
num_action: int,
chunk_size: int = 2048,
shard_config: Any = None,
vocab_size: int = None,
) -> torch.Tensor:
"""
Calculate action log probs in a memory-efficient way by processing in chunks.
Args:
logits (torch.Tensor): Output tensor of Actor.forward.logits.
inputs (torch.LongTensor): Input sequences.
num_action (int): Number of actions.
chunk_size (int, optional): Size of each chunk to process. Default is 2048.
shard_config: Shard configuration for distributed computation.
vocab_size (int, optional): Vocabulary size. Default is None.
Returns:
torch.Tensor: Action log probs.
"""
action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
context_length = logits.size(1) - num_action
for i in range(action_log_probs.size(0)):
# loop over each sample in the micro-batch
for start in range(context_length, logits.size(1), chunk_size):
end = min(start + chunk_size, logits.size(1))
# calculate log probs in chunks to save memory
log_probs = dist_log_prob(
inputs[i : i + 1, start - 1 : end],
logits[i : i + 1, start - 1 : end],
shard_config,
vocab_size,
logits.dtype,
) # [1, chunk_size, 1]
log_probs = log_probs.squeeze(-1)
action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
return action_log_probs
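
A quick way to sanity-check the chunking idea on toy shapes, with plain torch.log_softmax/gather standing in for the codebase-specific dist_log_prob (so no shard_config is needed; unlike dist_log_prob, the one-token shift is written out explicitly here). An illustrative sketch, not part of this commit:

import torch

def chunked_token_log_probs(logits, inputs, num_action, chunk_size=4):
    # same chunking idea as memory_efficient_logprob, single-GPU stand-in
    out = torch.zeros(logits.size(0), num_action)
    context_length = logits.size(1) - num_action
    for start in range(context_length, logits.size(1), chunk_size):
        end = min(start + chunk_size, logits.size(1))
        # logits at position t predict the token at t+1, hence the shift
        lp = torch.log_softmax(logits[:, start - 1 : end - 1], dim=-1)
        token_lp = lp.gather(-1, inputs[:, start:end].unsqueeze(-1)).squeeze(-1)
        out[:, start - context_length : end - context_length] = token_lp
    return out

# toy shapes: batch 2, sequence 16, vocab 32, 6 action tokens
B, S, V, A = 2, 16, 32, 6
logits, inputs = torch.randn(B, S, V), torch.randint(0, V, (B, S))
full = torch.log_softmax(logits[:, :-1], dim=-1).gather(-1, inputs[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -A:]
assert torch.allclose(full, chunked_token_log_probs(logits, inputs, A), atol=1e-6)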
class CustomProfiler:
def __init__(self, name):
self.name = name

View File

@@ -1,10 +1,8 @@
# Re-import required libraries due to kernel reset
import argparse
from collections import defaultdict
from datetime import datetime
import matplotlib.cm as cm
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
# Argument parser for command line arguments
@@ -26,6 +24,10 @@ for file in files:
actors = defaultdict(lambda: defaultdict(list))
current_entries = {}
# First, collect all timestamps to find the minimum
all_timestamps = []
parsed_lines = []
for line in log_lines:
if line.startswith("[Log]"):
continue
@@ -34,13 +36,23 @@ for line in log_lines:
actor = parts[1]
action = parts[3]
func_name = parts[4]
parsed_lines.append((timestamp, actor, action, func_name))
all_timestamps.append(timestamp)
if not all_timestamps:
raise ValueError("No valid log entries found.")
min_timestamp = min(all_timestamps)
for timestamp, actor, action, func_name in parsed_lines:
rel_timestamp = timestamp - min_timestamp
key = (actor, func_name)
if action == "Enter":
current_entries[key] = timestamp
current_entries[key] = rel_timestamp
elif action == "Exit":
start_time = current_entries.pop(key, None)
if start_time is not None:
actors[actor][func_name].append((start_time, timestamp))
actors[actor][func_name].append((start_time, rel_timestamp))
# Plotting setup
fig, ax = plt.subplots(figsize=(12, 6))
@@ -57,21 +69,23 @@ for idx, (actor, func_dict) in enumerate(actors.items()):
actor_offsets[actor] = base_offset
color = colors(idx)
for j, (func, intervals) in enumerate(func_dict.items()):
print(actor, func, intervals)
y_val = base_offset + j * function_spacing
yticks.append(y_val)
yticklabels.append(f"{actor}:{func}")
for start, end in intervals:
if end - start < 100:
end = start + 100 # Ensure minimum length of 100ms
ax.plot(
[datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
[start, end],
[y_val, y_val],
color=color,
linewidth=4,
linewidth=2,
label=actor if j == 0 else "",
)
base_offset += len(func_dict) * function_spacing + 1
# Formatting
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Time")

View File

@@ -23,6 +23,17 @@ export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png
# 16K
MAX_NEW_TOKENS=$((16384-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
# MAX_NEW_TOKENS=$((16384-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_s_2_ptp_2_16384_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmb2_pp_2_ptp_2_16384_GRPO_profile.png
# 8K
# MAX_NEW_TOKENS=$((8192-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO-DUMMY-TEST -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
# 32K
MAX_NEW_TOKENS=$((32768-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -tMbs 2 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -tp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.png

View File

@@ -318,8 +318,10 @@ if __name__ == "__main__":
plugin_config={
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
# "num_layers_per_stage": [12,12,2,2],
"num_layers_per_stage": [20, 8],
# "num_layers_per_stage": [18, 10],
# "num_layers_per_stage": [16, 11, 1],
# "num_layers_per_stage": [24, 4],
"num_layers_per_stage": [15, 13],
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size