add profiling, implement memory efficient logprob calculation

YeAnbang 2025-06-18 10:08:22 +00:00
parent ff1689b69a
commit 2db255bf15
8 changed files with 149 additions and 47 deletions

.gitignore (vendored): 1 addition
View File

@@ -173,3 +173,4 @@ applications/ColossalChat/stdin
  applications/ColossalChat/*.zip
  applications/ColossalChat/*.prof
  applications/ColossalChat/*.png
+ applications/ColossalChat/profiling_log/

View File

@@ -200,10 +200,15 @@ class BaseConsumer:
  }
  batch = bind_batch([t[0] for t in batches])
  batch = post_recv(batch)
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
  self.profiler.enter("step")
  loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
  total_step += 1
  self.profiler.exit("step")
+ self.profiler.log(
+     f"step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+ )
  self.buffer = self.buffer[
      effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
  ]

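The profiling added around self.step() above follows the usual CUDA peak-memory bracketing pattern: clear the cache, reset the peak counter, run the work, then read torch.cuda.max_memory_allocated(). A minimal standalone sketch of that pattern (run_step here is a stand-in for self.step, not part of the commit):

import torch

def profile_peak_memory(run_step, *args, **kwargs):
    """Run one step and report its peak CUDA memory in MB (sketch of the pattern above)."""
    torch.cuda.empty_cache()              # drop cached blocks so the peak reflects only this step
    torch.cuda.reset_peak_memory_stats()  # zero the allocator's high-water mark
    out = run_step(*args, **kwargs)
    peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    print(f"peak_memory: {peak_mb:.2f}MB")
    return out

max_memory_allocated() reports the high-water mark since the last reset on the current device, so each consumer rank logs its own figure via the profiler.log call above.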
View File

@@ -6,7 +6,7 @@ import torch
  import wandb
  from coati.distributed.consumer import BaseConsumer
  from coati.distributed.loss import PolicyLoss
- from coati.distributed.utils import calc_action_log_probs
+ from coati.distributed.utils import calc_action_log_probs, memory_efficient_logprob
  from coati.trainer.utils import all_reduce_mean, all_reduce_sum
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -262,7 +262,7 @@
  # Support training with PP.
  if self.policy_loss_fn.beta > 0:
      with torch.no_grad():
-         torch.cuda.reset_peak_memory_stats()
+         # torch.cuda.reset_peak_memory_stats()
          reference_model_outputs = self.booster.execute_pipeline(
              iter(
                  [
@@ -286,22 +286,32 @@
  if self.booster.plugin.stage_manager.is_last_stage():
      # breakpoint()
-     torch.cuda.reset_peak_memory_stats()
-     reference_action_log_probs = torch.zeros(
-         (input_ids_forward_micro_batch.size(0), num_action),
-         device=input_ids_forward_micro_batch.device,
-     )
-     for i in range(reference_action_log_probs.size(0)):
-         # activation for log_softmax is too large if vocab size and sequence length are large
-         # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-         # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-         reference_action_log_probs[i, :] += calc_action_log_probs(
-             reference_model_outputs["outputs"]["logits"][i : i + 1]
-             / self.generate_config["temperature"],
-             input_ids_forward_micro_batch[i : i + 1],
-             num_action,
-             self.plugin.shard_config,
-         )[0]
+     # torch.cuda.empty_cache()
+     # current_memory = torch.cuda.memory_allocated()
+     # torch.cuda.reset_peak_memory_stats()
+     # reference_action_log_probs = calc_action_log_probs(
+     #     reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
+     #     input_ids_forward_micro_batch,
+     #     num_action,
+     #     self.plugin.shard_config,
+     # )
+     # self.profiler.log(f"reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
+     # torch.cuda.empty_cache()
+     # current_memory = torch.cuda.memory_allocated()
+     # torch.cuda.reset_peak_memory_stats()
+     reference_action_log_probs = memory_efficient_logprob(
+         reference_model_outputs["outputs"]["logits"],
+         input_ids_forward_micro_batch,
+         num_action,
+         shard_config=self.plugin.shard_config,
+     )
+     # self.profiler.log(f"me_reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
+     # if torch.allclose(reference_action_log_probs, me_reference_action_log_probs):
+     #     self.profiler.log("Memory efficient reference action log probs is same as normal reference action log probs")
+     # else:
+     #     self.profiler.log("Memory efficient reference action log probs is different from normal reference action log probs")
      # breakpoint()
      # torch.cuda.empty_cache()
      self.profiler.log(
@@ -310,6 +320,7 @@
      else:
          # Dummy reference logprobs for data iterator.
          reference_action_log_probs = None
+     del reference_model_outputs
  else:
      reference_action_log_probs = None
@@ -328,26 +339,43 @@
  def _criterion(outputs, inputs):
      action_logits = outputs.logits
-     action_log_probs = torch.zeros(
-         (inputs["input_ids"].size(0), num_action), device=action_logits.device
-     )
      # breakpoint()
-     torch.cuda.reset_peak_memory_stats()
-     for i in range(action_log_probs.size(0)):
-         # activation for log_softmax is too large if vocab size and sequence length are large
-         # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-         # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-         action_log_probs[i, :] += calc_action_log_probs(
-             action_logits[i : i + 1] / self.generate_config["temperature"],
-             inputs["input_ids"][i : i + 1],
-             num_action,
-             self.plugin.shard_config,
-         )[0]
      # torch.cuda.empty_cache()
-     self.profiler.log(
-         f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
-     )
+     # current_memory = torch.cuda.memory_allocated()
+     # torch.cuda.reset_peak_memory_stats()
+     # action_log_probs = calc_action_log_probs(
+     #     action_logits / self.generate_config["temperature"],
+     #     inputs["input_ids"],
+     #     num_action,
+     #     self.plugin.shard_config,
+     # )
+     # # torch.cuda.empty_cache()
+     # self.profiler.log(
+     #     f"action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+     # )
+     # torch.cuda.empty_cache()
+     # current_memory = torch.cuda.memory_allocated()
+     # torch.cuda.reset_peak_memory_stats()
+     action_log_probs = memory_efficient_logprob(
+         action_logits,
+         inputs["input_ids"],
+         num_action,
+         shard_config=self.plugin.shard_config,
+     )
+     # self.profiler.log(
+     #     f"me_action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+     # )
+     # if torch.allclose(action_log_probs, me_action_log_probs):
+     #     self.profiler.log("Memory efficient action log probs is same as normal action log probs")
+     # else:
+     #     self.profiler.log("Memory efficient action log probs is different from normal action log probs")
+     # torch.cuda.empty_cache()
      # breakpoint()
+     # current_memory = torch.cuda.memory_allocated()
      # torch.cuda.empty_cache()
+     # self.profiler.log(
+     #     f"released by del outputs: {(torch.cuda.memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+     # )
      if "reference_action_log_probs" in inputs:
          per_token_kl = (
              torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -463,7 +491,7 @@
  if need_update:
      # breakpoint()
      # torch.cuda.empty_cache()
-     torch.cuda.reset_peak_memory_stats()
+     # torch.cuda.reset_peak_memory_stats()
      self.optimizer.step()
      self.optimizer.zero_grad()
      self.global_step += 1

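The comments removed above explain why the full-vocabulary log_softmax used to be computed one sample at a time: its activation alone is enormous at long sequence lengths. The arithmetic quoted there checks out, and the chunked replacement bounds the live activation by chunk_size rather than the full sequence length (the sizes below simply reproduce the figures from the old comment):

# Full log_softmax activation quoted in the removed comment:
# a float32 tensor of shape [micro_batch=4, seq=32000, vocab=152064]
vocab, seq_len, micro_bs, fp32_bytes = 152064, 32000, 4, 4
print(vocab * seq_len * micro_bs * fp32_bytes / 1024**3)  # ~72.5 GB

# With memory_efficient_logprob, only one [1, chunk_size, vocab] slice is alive at a time
chunk_size = 2048
print(vocab * chunk_size * 1 * fp32_bytes / 1024**3)      # ~1.2 GB per chunk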
View File

@@ -448,7 +448,9 @@ class DummyProducer(BaseProducer):
  "input_ids": torch.cat(
      [
          torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
-         torch.ones(
+         torch.randint(
+             0,
+             self.tokenizer.vocab_size,  # assuming vocab_size is available in tokenizer
              (input_ids.size(0), self.num_generations, num_new_tokens),
              dtype=input_ids.dtype,
          ).to(device),

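The producer change above swaps the all-ones filler for random token ids drawn from the tokenizer vocabulary, so profiled batches look more like real generations. A small self-contained sketch of the same construction (the batch size, prompt length, and vocab size here are made up for illustration):

import torch

prompt_ids = torch.randint(0, 32000, (2, 8))      # [batch, prompt_len], toy vocab of 32000
num_generations, num_new_tokens = 4, 16

dummy_rollouts = torch.cat(
    [
        # one copy of the prompt per generation: [batch, num_generations, prompt_len]
        torch.repeat_interleave(prompt_ids.unsqueeze(1), num_generations, dim=1),
        # random "generated" tokens instead of torch.ones: [batch, num_generations, num_new_tokens]
        torch.randint(0, 32000, (2, num_generations, num_new_tokens), dtype=prompt_ids.dtype),
    ],
    dim=-1,
)  # -> shape [2, 4, 24]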
View File

@@ -146,6 +146,45 @@ def safe_append_to_jsonl_file(file_path, data):
          f.write(json_line + "\n")

+ def memory_efficient_logprob(
+     logits: torch.Tensor,
+     inputs: torch.Tensor,
+     num_action: int,
+     chunk_size: int = 2048,
+     shard_config: Any = None,
+     vocab_size: int = None,
+ ) -> torch.Tensor:
+     """
+     Calculate action log probs in a memory-efficient way by processing in chunks.
+     Args:
+         logits (torch.Tensor): Output tensor of Actor.forward.logits.
+         inputs (torch.LongTensor): Input sequences.
+         num_action (int): Number of actions.
+         chunk_size (int, optional): Size of each chunk to process. Default is 2048.
+         shard_config: Shard configuration for distributed computation.
+         vocab_size (int, optional): Vocabulary size. Default is None.
+     Returns:
+         torch.Tensor: Action log probs.
+     """
+     action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
+     context_length = logits.size(1) - num_action
+     for i in range(action_log_probs.size(0)):
+         # loop over each sample in the micro-batch
+         for start in range(context_length, logits.size(1), chunk_size):
+             end = min(start + chunk_size, logits.size(1))
+             # calculate log probs in chunks to save memory
+             log_probs = dist_log_prob(
+                 inputs[i : i + 1, start - 1 : end],
+                 logits[i : i + 1, start - 1 : end],
+                 shard_config,
+                 vocab_size,
+                 logits.dtype,
+             )  # [1, chunk_size, 1]
+             log_probs = log_probs.squeeze(-1)
+             action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
+     return action_log_probs

  class CustomProfiler:
      def __init__(self, name):
          self.name = name

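memory_efficient_logprob above delegates the per-chunk work to dist_log_prob with the plugin's shard_config, neither of which appears in this diff. As a rough single-GPU analogue of the same chunking idea (the helper below uses plain log_softmax and gather; its name and the explicit one-position shift are illustrative, not the project's API):

import torch
import torch.nn.functional as F

def chunked_action_logprob(logits, input_ids, num_action, chunk_size=2048):
    """Single-GPU sketch: only a [1, chunk, vocab] log_softmax slice is alive at a time."""
    bsz, seq_len, _ = logits.shape
    context_length = seq_len - num_action
    out = torch.zeros((bsz, num_action), device=logits.device, dtype=logits.dtype)
    for i in range(bsz):
        for start in range(context_length, seq_len, chunk_size):
            end = min(start + chunk_size, seq_len)
            chunk_logits = logits[i : i + 1, start - 1 : end - 1]  # logits at t predict token t+1
            targets = input_ids[i : i + 1, start:end]
            log_probs = F.log_softmax(chunk_logits, dim=-1)
            token_logp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
            out[i, start - context_length : end - context_length] = token_logp[0]
    return out

# Equivalence check against the full-vocabulary computation (tiny sizes for illustration),
# mirroring the torch.allclose comparisons left commented out in the consumer:
logits = torch.randn(2, 16, 100)
ids = torch.randint(0, 100, (2, 16))
full = F.log_softmax(logits[:, :-1], dim=-1).gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1)[:, -8:]
assert torch.allclose(chunked_action_logprob(logits, ids, num_action=8, chunk_size=4), full, atol=1e-6)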
View File

@@ -1,10 +1,8 @@
  # Re-import required libraries due to kernel reset
  import argparse
  from collections import defaultdict
- from datetime import datetime

  import matplotlib.cm as cm
- import matplotlib.dates as mdates
  import matplotlib.pyplot as plt

  # Argument parser for command line arguments
@@ -26,6 +24,10 @@ for file in files:
  actors = defaultdict(lambda: defaultdict(list))
  current_entries = {}

+ # First, collect all timestamps to find the minimum
+ all_timestamps = []
+ parsed_lines = []
+
  for line in log_lines:
      if line.startswith("[Log]"):
          continue
@@ -34,13 +36,23 @@
      actor = parts[1]
      action = parts[3]
      func_name = parts[4]
+     parsed_lines.append((timestamp, actor, action, func_name))
+     all_timestamps.append(timestamp)
+
+ if not all_timestamps:
+     raise ValueError("No valid log entries found.")
+
+ min_timestamp = min(all_timestamps)
+
+ for timestamp, actor, action, func_name in parsed_lines:
+     rel_timestamp = timestamp - min_timestamp
      key = (actor, func_name)
      if action == "Enter":
-         current_entries[key] = timestamp
+         current_entries[key] = rel_timestamp
      elif action == "Exit":
          start_time = current_entries.pop(key, None)
          if start_time is not None:
-             actors[actor][func_name].append((start_time, timestamp))
+             actors[actor][func_name].append((start_time, rel_timestamp))

  # Plotting setup
  fig, ax = plt.subplots(figsize=(12, 6))
@@ -57,21 +69,23 @@ for idx, (actor, func_dict) in enumerate(actors.items()):
      actor_offsets[actor] = base_offset
      color = colors(idx)
      for j, (func, intervals) in enumerate(func_dict.items()):
+         print(actor, func, intervals)
          y_val = base_offset + j * function_spacing
          yticks.append(y_val)
          yticklabels.append(f"{actor}:{func}")
          for start, end in intervals:
+             if end - start < 100:
+                 end = start + 100  # Ensure minimum length of 100ms
              ax.plot(
-                 [datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
+                 [start, end],
                  [y_val, y_val],
                  color=color,
-                 linewidth=4,
+                 linewidth=2,
                  label=actor if j == 0 else "",
              )
      base_offset += len(func_dict) * function_spacing + 1

  # Formatting
- ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
  ax.set_yticks(yticks)
  ax.set_yticklabels(yticklabels)
  ax.set_xlabel("Time")

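The visualization changes above switch the x-axis from wall-clock datetimes to time elapsed since the first log entry, pairing Enter/Exit records per (actor, function) and padding very short intervals so they stay visible. A condensed sketch of that two-pass approach (the record tuples and the 0.1 minimum bar length are illustrative; the real script parses these fields out of the profiler log):

import matplotlib.pyplot as plt

# (timestamp, actor, action, func) tuples as the real script parses them from the log
records = [(100.0, "consumer0", "Enter", "step"), (102.5, "consumer0", "Exit", "step")]

t0 = min(ts for ts, *_ in records)             # pass 1: earliest timestamp becomes the origin
intervals, open_calls = [], {}
for ts, actor, action, func in records:        # pass 2: pair Enter/Exit on relative time
    key = (actor, func)
    if action == "Enter":
        open_calls[key] = ts - t0
    elif action == "Exit" and key in open_calls:
        intervals.append((key, open_calls.pop(key), ts - t0))

fig, ax = plt.subplots(figsize=(12, 6))
for y, ((actor, func), start, end) in enumerate(intervals):
    end = max(end, start + 0.1)                # keep very short calls visible
    ax.plot([start, end], [y, y], linewidth=2, label=f"{actor}:{func}")
ax.set_xlabel("Time since start")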
View File

@@ -23,6 +23,17 @@ export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
  # python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png

  # 16K
- MAX_NEW_TOKENS=$((16384-512))
- python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
- python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
+ # MAX_NEW_TOKENS=$((16384-512))
+ # python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_s_2_ptp_2_16384_GRPO_profile.txt
+ # python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmb2_pp_2_ptp_2_16384_GRPO_profile.png

+ # 8K
+ # MAX_NEW_TOKENS=$((8192-512))
+ # python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO-DUMMY-TEST -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
+ # python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png

+ # # 32K
+ MAX_NEW_TOKENS=$((32768-512))
+ python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -tMbs 2 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -tp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.txt
+ python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.png

View File

@@ -318,8 +318,10 @@ if __name__ == "__main__":
  plugin_config={
      "tp_size": args.tensor_parallel_size,
      "pp_size": args.pipeline_parallel_size,
-     # "num_layers_per_stage": [12,12,2,2],
-     "num_layers_per_stage": [20, 8],
+     # "num_layers_per_stage": [18, 10],
+     # "num_layers_per_stage": [16, 11, 1],
+     # "num_layers_per_stage": [24, 4],
+     "num_layers_per_stage": [15, 13],
      "microbatch_size": max(
          1, args.train_microbatch_size // args.pipeline_parallel_size
      ),  # microbatch size should be set to train_microbatch_size // pp_size
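The hand-tuned num_layers_per_stage entries above are different ways of splitting the model's decoder layers across pipeline stages to balance memory; every listed split sums to 28, which matches Qwen2.5-Math-7B's layer count if that assumption holds. A tiny helper one might use to sanity-check such a split before launching a run (the helper itself is not part of the repo):

def check_stage_split(num_layers_per_stage, total_layers=28, pp_size=2):
    """Sanity-check a manual pipeline split; total_layers=28 is assumed for Qwen2.5-Math-7B."""
    assert len(num_layers_per_stage) == pp_size, "need one entry per pipeline stage"
    assert all(n > 0 for n in num_layers_per_stage), "no empty stages"
    assert sum(num_layers_per_stage) == total_layers, "stages must cover every layer exactly once"
    return num_layers_per_stage

check_stage_split([15, 13])                # the split enabled in this commit
check_stage_split([16, 11, 1], pp_size=3)  # one of the commented-out alternatives (three stages)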