mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-31 07:18:59 +00:00
add profiling, implement memory efficient logprob calculation
This commit is contained in:
parent
ff1689b69a
commit
2db255bf15
.gitignore (vendored)
@@ -173,3 +173,4 @@ applications/ColossalChat/stdin
applications/ColossalChat/*.zip
applications/ColossalChat/*.prof
applications/ColossalChat/*.png
applications/ColossalChat/profiling_log/
@@ -200,10 +200,15 @@ class BaseConsumer:
    }
    batch = bind_batch([t[0] for t in batches])
    batch = post_recv(batch)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    self.profiler.enter("step")
    loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
    total_step += 1
    self.profiler.exit("step")
    self.profiler.log(
        f"step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
    )
    self.buffer = self.buffer[
        effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
    ]
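The hunk above brackets each training step with torch.cuda.reset_peak_memory_stats() before the call and torch.cuda.max_memory_allocated() after it, which is how a peak-memory reading gets attributed to a single region of code. A minimal stand-alone sketch of that pattern (the wrapper name and step_fn argument are illustrative, not part of the patch; requires a CUDA device):

import torch

def run_with_peak_memory(step_fn, *args, **kwargs):
    # Free cached blocks and reset the peak counter so the reading covers only this call.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    result = step_fn(*args, **kwargs)
    peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    print(f"peak_memory: {peak_mb:.2f}MB")
    return result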
@@ -6,7 +6,7 @@ import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import calc_action_log_probs
from coati.distributed.utils import calc_action_log_probs, memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -262,7 +262,7 @@ class GRPOConsumer(BaseConsumer):
# Support training with PP.
if self.policy_loss_fn.beta > 0:
    with torch.no_grad():
        torch.cuda.reset_peak_memory_stats()
        # torch.cuda.reset_peak_memory_stats()
        reference_model_outputs = self.booster.execute_pipeline(
            iter(
                [
@@ -286,22 +286,32 @@ class GRPOConsumer(BaseConsumer):

if self.booster.plugin.stage_manager.is_last_stage():
    # breakpoint()
    torch.cuda.reset_peak_memory_stats()
    reference_action_log_probs = torch.zeros(
        (input_ids_forward_micro_batch.size(0), num_action),
        device=input_ids_forward_micro_batch.device,
    # torch.cuda.empty_cache()
    # current_memory = torch.cuda.memory_allocated()
    # torch.cuda.reset_peak_memory_stats()

    # reference_action_log_probs = calc_action_log_probs(
    #     reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
    #     input_ids_forward_micro_batch,
    #     num_action,
    #     self.plugin.shard_config,
    # )

    # self.profiler.log(f"reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
    # torch.cuda.empty_cache()
    # current_memory = torch.cuda.memory_allocated()
    # torch.cuda.reset_peak_memory_stats()
    reference_action_log_probs = memory_efficient_logprob(
        reference_model_outputs["outputs"]["logits"],
        input_ids_forward_micro_batch,
        num_action,
        shard_config=self.plugin.shard_config,
    )
    for i in range(reference_action_log_probs.size(0)):
        # activation for log_softmax is too large if vocab size and sequence length are large
        # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
        # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
        reference_action_log_probs[i, :] += calc_action_log_probs(
            reference_model_outputs["outputs"]["logits"][i : i + 1]
            / self.generate_config["temperature"],
            input_ids_forward_micro_batch[i : i + 1],
            num_action,
            self.plugin.shard_config,
        )[0]
    # self.profiler.log(f"me_reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
    # if torch.allclose(reference_action_log_probs, me_reference_action_log_probs):
    #     self.profiler.log("Memory efficient reference action log probs is same as normal reference action log probs")
    # else:
    #     self.profiler.log("Memory efficient reference action log probs is different from normal reference action log probs")
    # breakpoint()
    # torch.cuda.empty_cache()
    self.profiler.log(
@@ -310,6 +320,7 @@ class GRPOConsumer(BaseConsumer):
    else:
        # Dummy reference logprobs for data iterator.
        reference_action_log_probs = None
        del reference_model_outputs
else:
    reference_action_log_probs = None

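The 72.5 GB figure quoted in the code comments above (and again in the _criterion hunk below) corresponds to materializing a full-vocabulary log_softmax for every position of every sequence in the micro batch, which is what the chunked memory_efficient_logprob avoids. A quick sanity check of that number, using the values from the comments and assuming fp32 (4 bytes per element):

# Values taken from the code comments; fp32 activations are an assumption.
vocab_size = 152064       # vocabulary size quoted in the comment
seq_len = 32000           # ~32K sequence length
micro_batch_size = 4      # micro batch size (pp=4 example)
bytes_per_element = 4     # fp32
activation_gib = vocab_size * seq_len * micro_batch_size * bytes_per_element / 1024**3
print(f"{activation_gib:.1f} GiB")  # ~72.5, matching the figure in the comment
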
@@ -328,26 +339,43 @@ class GRPOConsumer(BaseConsumer):

def _criterion(outputs, inputs):
    action_logits = outputs.logits
    action_log_probs = torch.zeros(
        (inputs["input_ids"].size(0), num_action), device=action_logits.device
    )
    # breakpoint()
    torch.cuda.reset_peak_memory_stats()
    for i in range(action_log_probs.size(0)):
        # activation for log_softmax is too large if vocab size and sequence length are large
        # e.g., when using 152064 vocab size with 32K sequence length and a micro batch size of 4 (for pp=4 for example),
        # this activation alone takes 152064*32000*4*4/1024/1024/1024=72.5GB
        action_log_probs[i, :] += calc_action_log_probs(
            action_logits[i : i + 1] / self.generate_config["temperature"],
            inputs["input_ids"][i : i + 1],
            num_action,
            self.plugin.shard_config,
        )[0]
    # torch.cuda.empty_cache()
    self.profiler.log(
        f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
    # current_memory = torch.cuda.memory_allocated()
    # torch.cuda.reset_peak_memory_stats()
    # action_log_probs = calc_action_log_probs(
    #     action_logits / self.generate_config["temperature"],
    #     inputs["input_ids"],
    #     num_action,
    #     self.plugin.shard_config,
    # )
    # # torch.cuda.empty_cache()
    # self.profiler.log(
    #     f"action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
    # )
    # torch.cuda.empty_cache()
    # current_memory = torch.cuda.memory_allocated()
    # torch.cuda.reset_peak_memory_stats()
    action_log_probs = memory_efficient_logprob(
        action_logits,
        inputs["input_ids"],
        num_action,
        shard_config=self.plugin.shard_config,
    )
    # self.profiler.log(
    #     f"me_action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
    # )
    # if torch.allclose(action_log_probs, me_action_log_probs):
    #     self.profiler.log("Memory efficient action log probs is same as normal action log probs")
    # else:
    #     self.profiler.log("Memory efficient action log probs is different from normal action log probs")
    # torch.cuda.empty_cache()
    # breakpoint()
    # current_memory = torch.cuda.memory_allocated()
    # torch.cuda.empty_cache()
    # self.profiler.log(
    #     f"released by del outputs: {(torch.cuda.memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
    # )
    if "reference_action_log_probs" in inputs:
        per_token_kl = (
            torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
@@ -463,7 +491,7 @@ class GRPOConsumer(BaseConsumer):
if need_update:
    # breakpoint()
    # torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # torch.cuda.reset_peak_memory_stats()
    self.optimizer.step()
    self.optimizer.zero_grad()
    self.global_step += 1

@@ -448,7 +448,9 @@ class DummyProducer(BaseProducer):
"input_ids": torch.cat(
    [
        torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
        torch.ones(
        torch.randint(
            0,
            self.tokenizer.vocab_size,  # assuming vocab_size is available in tokenizer
            (input_ids.size(0), self.num_generations, num_new_tokens),
            dtype=input_ids.dtype,
        ).to(device),
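The producer change above swaps constant dummy tokens (torch.ones) for token ids sampled uniformly from the tokenizer vocabulary, presumably so that profiled runs exercise varied, valid token ids rather than a single repeated one. A hedged, stand-alone illustration of the sampled tensor (the sizes below are placeholders, not values from the patch):

import torch

batch_size, num_generations, num_new_tokens, vocab_size = 2, 4, 8, 152064  # illustrative sizes
# Uniform token ids in [0, vocab_size) standing in for real rollout tokens during profiling.
dummy_new_tokens = torch.randint(0, vocab_size, (batch_size, num_generations, num_new_tokens), dtype=torch.long)
print(dummy_new_tokens.shape)  # torch.Size([2, 4, 8])
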
@@ -146,6 +146,45 @@ def safe_append_to_jsonl_file(file_path, data):
            f.write(json_line + "\n")


def memory_efficient_logprob(
    logits: torch.Tensor,
    inputs: torch.Tensor,
    num_action: int,
    chunk_size: int = 2048,
    shard_config: Any = None,
    vocab_size: int = None,
) -> torch.Tensor:
    """
    Calculate action log probs in a memory-efficient way by processing in chunks.

    Args:
        logits (torch.Tensor): Output tensor of Actor.forward.logits.
        inputs (torch.LongTensor): Input sequences.
        num_action (int): Number of actions.
        chunk_size (int, optional): Size of each chunk to process. Default is 2048.
        shard_config: Shard configuration for distributed computation.
        vocab_size (int, optional): Vocabulary size. Default is None.

    Returns:
        torch.Tensor: Action log probs.
    """
    action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
    context_length = logits.size(1) - num_action
    for i in range(action_log_probs.size(0)):
        # loop over each sample in the micro-batch
        for start in range(context_length, logits.size(1), chunk_size):
            end = min(start + chunk_size, logits.size(1))
            # calculate log probs in chunks to save memory
            log_probs = dist_log_prob(
                inputs[i : i + 1, start - 1 : end],
                logits[i : i + 1, start - 1 : end],
                shard_config,
                vocab_size,
                logits.dtype,
            )  # [1, chunk_size, 1]
            log_probs = log_probs.squeeze(-1)
            action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
    return action_log_probs
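The helper above bounds peak activation memory by only ever holding a [1, chunk_size, vocab] slice of the log-probability computation, instead of the full [batch, seq_len, vocab] tensor a one-shot log_softmax would require. A simplified, self-contained rendition of the same chunking idea, using plain torch.log_softmax in place of dist_log_prob (a sketch of the technique under single-device assumptions, not the repo helper itself):

import torch

def chunked_action_logprob(logits, input_ids, num_action, chunk_size=2048):
    # logits: [B, S, V]; the last num_action positions of input_ids are the "actions".
    context_length = logits.size(1) - num_action
    out = torch.zeros(logits.size(0), num_action, device=logits.device, dtype=logits.dtype)
    for start in range(context_length, logits.size(1), chunk_size):
        end = min(start + chunk_size, logits.size(1))
        # Logits at position t-1 predict the token at position t, hence the shift by one.
        chunk_logits = logits[:, start - 1 : end - 1, :]
        chunk_targets = input_ids[:, start:end]
        # Only a [B, chunk, V] log_softmax is alive at any time instead of [B, S, V].
        log_probs = torch.log_softmax(chunk_logits.float(), dim=-1)
        token_logprob = log_probs.gather(-1, chunk_targets.unsqueeze(-1)).squeeze(-1)
        out[:, start - context_length : end - context_length] = token_logprob.to(logits.dtype)
    return out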


class CustomProfiler:
    def __init__(self, name):
        self.name = name
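The diff only shows the opening lines of CustomProfiler, but the consumer code above relies on an enter/exit/log interface, and the plotting script below skips lines that start with "[Log]" and reads actor, action, and function name from whitespace-separated fields. A hedged sketch of a profiler that satisfies those constraints (the exact log-line layout and the profiling_log/ output path are assumptions, not the actual implementation):

import os
import time


class SketchProfiler:
    def __init__(self, name):
        self.name = name  # actor name, e.g. a consumer or producer rank
        os.makedirs("profiling_log", exist_ok=True)  # hypothetical output directory
        self.file = open(os.path.join("profiling_log", f"{name}.prof"), "a")

    def _write(self, line):
        self.file.write(line + "\n")
        self.file.flush()

    def enter(self, func_name):
        # timestamp, actor, separator, action, function name (field layout is an assumption)
        self._write(f"{time.time() * 1000:.0f} {self.name} - Enter {func_name}")

    def exit(self, func_name):
        self._write(f"{time.time() * 1000:.0f} {self.name} - Exit {func_name}")

    def log(self, message):
        # free-form messages carry a prefix the plotting script is shown to skip
        self._write(f"[Log] {message}")
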
@@ -1,10 +1,8 @@
# Re-import required libraries due to kernel reset
import argparse
from collections import defaultdict
from datetime import datetime

import matplotlib.cm as cm
import matplotlib.dates as mdates
import matplotlib.pyplot as plt

# Argument parser for command line arguments
@@ -26,6 +24,10 @@ for file in files:
actors = defaultdict(lambda: defaultdict(list))
current_entries = {}

# First, collect all timestamps to find the minimum
all_timestamps = []
parsed_lines = []

for line in log_lines:
    if line.startswith("[Log]"):
        continue
@@ -34,13 +36,23 @@
    actor = parts[1]
    action = parts[3]
    func_name = parts[4]
    parsed_lines.append((timestamp, actor, action, func_name))
    all_timestamps.append(timestamp)

if not all_timestamps:
    raise ValueError("No valid log entries found.")

min_timestamp = min(all_timestamps)

for timestamp, actor, action, func_name in parsed_lines:
    rel_timestamp = timestamp - min_timestamp
    key = (actor, func_name)
    if action == "Enter":
        current_entries[key] = timestamp
        current_entries[key] = rel_timestamp
    elif action == "Exit":
        start_time = current_entries.pop(key, None)
        if start_time is not None:
            actors[actor][func_name].append((start_time, timestamp))
            actors[actor][func_name].append((start_time, rel_timestamp))

# Plotting setup
fig, ax = plt.subplots(figsize=(12, 6))
@@ -57,21 +69,23 @@ for idx, (actor, func_dict) in enumerate(actors.items()):
    actor_offsets[actor] = base_offset
    color = colors(idx)
    for j, (func, intervals) in enumerate(func_dict.items()):
        print(actor, func, intervals)
        y_val = base_offset + j * function_spacing
        yticks.append(y_val)
        yticklabels.append(f"{actor}:{func}")
        for start, end in intervals:
            if end - start < 100:
                end = start + 100  # Ensure minimum length of 100ms
            ax.plot(
                [datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
                [start, end],
                [y_val, y_val],
                color=color,
                linewidth=4,
                linewidth=2,
                label=actor if j == 0 else "",
            )
    base_offset += len(func_dict) * function_spacing + 1

# Formatting
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Time")
@@ -23,6 +23,17 @@ export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png

# 16K
MAX_NEW_TOKENS=$((16384-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
# MAX_NEW_TOKENS=$((16384-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_s_2_ptp_2_16384_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmb2_pp_2_ptp_2_16384_GRPO_profile.png

# 8K
# MAX_NEW_TOKENS=$((8192-512))
# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO-DUMMY-TEST -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png


# # 32K
MAX_NEW_TOKENS=$((32768-512))
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -tMbs 2 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -tp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.txt
python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.png
@@ -318,8 +318,10 @@ if __name__ == "__main__":
plugin_config={
    "tp_size": args.tensor_parallel_size,
    "pp_size": args.pipeline_parallel_size,
    # "num_layers_per_stage": [12,12,2,2],
    "num_layers_per_stage": [20, 8],
    # "num_layers_per_stage": [18, 10],
    # "num_layers_per_stage": [16, 11, 1],
    # "num_layers_per_stage": [24, 4],
    "num_layers_per_stage": [15, 13],
    "microbatch_size": max(
        1, args.train_microbatch_size // args.pipeline_parallel_size
    ),  # microbatch size should be set to train_microbatch_size // pp_size