Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-03 17:19:51 +00:00

add profiling, implement memory efficient logprob calculation

This commit is contained in:
parent ff1689b69a
commit 2db255bf15

.gitignore (vendored) — 1 line changed
@@ -173,3 +173,4 @@ applications/ColossalChat/stdin
 applications/ColossalChat/*.zip
 applications/ColossalChat/*.prof
 applications/ColossalChat/*.png
+applications/ColossalChat/profiling_log/
@@ -200,10 +200,15 @@ class BaseConsumer:
 }
 batch = bind_batch([t[0] for t in batches])
 batch = post_recv(batch)
+torch.cuda.empty_cache()
+torch.cuda.reset_peak_memory_stats()
 self.profiler.enter("step")
 loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
 total_step += 1
 self.profiler.exit("step")
+self.profiler.log(
+    f"step_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+)
 self.buffer = self.buffer[
     effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
 ]
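Note: the two added torch.cuda calls plus the profiler.log after the step form a simple per-step peak-memory measurement. A minimal, self-contained sketch of that pattern (the function name, workload, and label below are illustrative, not part of the commit):

import torch

def measure_peak_memory(workload, label: str = "step") -> float:
    """Run `workload()` and report the peak CUDA memory it allocated, in MB."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA device required for memory profiling")
    torch.cuda.empty_cache()              # drop cached blocks so the peak reflects this step only
    torch.cuda.reset_peak_memory_stats()  # zero the allocator's high-water mark
    workload()
    torch.cuda.synchronize()              # make sure all launched kernels have finished
    peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    print(f"{label}: peak_memory: {peak_mb:.2f}MB")
    return peak_mb

if __name__ == "__main__":
    # toy workload: a large matmul on the GPU
    x = torch.randn(4096, 4096, device="cuda")
    measure_peak_memory(lambda: x @ x, label="matmul")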
@@ -6,7 +6,7 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import calc_action_log_probs, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -262,7 +262,7 @@ class GRPOConsumer(BaseConsumer):
 # Support training with PP.
 if self.policy_loss_fn.beta > 0:
     with torch.no_grad():
-        torch.cuda.reset_peak_memory_stats()
+        # torch.cuda.reset_peak_memory_stats()
         reference_model_outputs = self.booster.execute_pipeline(
             iter(
                 [
@@ -286,22 +286,32 @@ class GRPOConsumer(BaseConsumer):
 
 if self.booster.plugin.stage_manager.is_last_stage():
     # breakpoint()
-    torch.cuda.reset_peak_memory_stats()
-    reference_action_log_probs = torch.zeros(
-        (input_ids_forward_micro_batch.size(0), num_action),
-        device=input_ids_forward_micro_batch.device,
+    # torch.cuda.empty_cache()
+    # current_memory = torch.cuda.memory_allocated()
+    # torch.cuda.reset_peak_memory_stats()
+    # reference_action_log_probs = calc_action_log_probs(
+    #     reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
+    #     input_ids_forward_micro_batch,
+    #     num_action,
+    #     self.plugin.shard_config,
+    # )
+    # self.profiler.log(f"reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
+    # torch.cuda.empty_cache()
+    # current_memory = torch.cuda.memory_allocated()
+    # torch.cuda.reset_peak_memory_stats()
+    reference_action_log_probs = memory_efficient_logprob(
+        reference_model_outputs["outputs"]["logits"],
+        input_ids_forward_micro_batch,
+        num_action,
+        shard_config=self.plugin.shard_config,
     )
-    for i in range(reference_action_log_probs.size(0)):
-        # activation for log_softmax is too large if vocab size and sequence length are large
-        # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-        # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-        reference_action_log_probs[i, :] += calc_action_log_probs(
-            reference_model_outputs["outputs"]["logits"][i : i + 1]
-            / self.generate_config["temperature"],
-            input_ids_forward_micro_batch[i : i + 1],
-            num_action,
-            self.plugin.shard_config,
-        )[0]
+    # self.profiler.log(f"me_reference_action_log_probs: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB")
+    # if torch.allclose(reference_action_log_probs, me_reference_action_log_probs):
+    #     self.profiler.log("Memory efficient reference action log probs is same as normal reference action log probs")
+    # else:
+    #     self.profiler.log("Memory efficient reference action log probs is different from normal reference action log probs")
     # breakpoint()
     # torch.cuda.empty_cache()
     self.profiler.log(
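Note: the figure in the comment removed above can be reproduced directly; a full-precision log-softmax over the whole logits tensor materializes roughly vocab_size x seq_len x micro_batch x 4 bytes of activation. A quick check in plain Python, using the numbers quoted in that comment:

vocab_size = 152064   # Qwen2.5 vocabulary size quoted in the removed comment
seq_len = 32000       # ~32K tokens
micro_bs = 4          # micro batch size (e.g. with pp=4)
bytes_fp32 = 4

activation_gb = vocab_size * seq_len * micro_bs * bytes_fp32 / 1024**3
print(f"{activation_gb:.1f} GB")  # ~72.5 GB, matching the figure in the removed comment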
@@ -310,6 +320,7 @@ class GRPOConsumer(BaseConsumer):
     else:
         # Dummy reference logprobs for data iterator.
         reference_action_log_probs = None
+    del reference_model_outputs
 else:
     reference_action_log_probs = None
 
@@ -328,26 +339,43 @@ class GRPOConsumer(BaseConsumer):
 
 def _criterion(outputs, inputs):
     action_logits = outputs.logits
-    action_log_probs = torch.zeros(
-        (inputs["input_ids"].size(0), num_action), device=action_logits.device
-    )
     # breakpoint()
-    torch.cuda.reset_peak_memory_stats()
-    for i in range(action_log_probs.size(0)):
-        # activation for log_softmax is too large if vocab size and sequence length are large
-        # e.g., when using 152064 vocab size with 32K seqence length and a micro batch size of 4 (for pp=4 for example),
-        # this activation sorely takes 152064*32000*4*4/1024/1024/1024=72.5GB
-        action_log_probs[i, :] += calc_action_log_probs(
-            action_logits[i : i + 1] / self.generate_config["temperature"],
-            inputs["input_ids"][i : i + 1],
-            num_action,
-            self.plugin.shard_config,
-        )[0]
     # torch.cuda.empty_cache()
-    self.profiler.log(
-        f"action_log_probs_{self.global_step}: peak_memory: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MB"
+    # current_memory = torch.cuda.memory_allocated()
+    # torch.cuda.reset_peak_memory_stats()
+    # action_log_probs = calc_action_log_probs(
+    #     action_logits / self.generate_config["temperature"],
+    #     inputs["input_ids"],
+    #     num_action,
+    #     self.plugin.shard_config,
+    # )
+    # # torch.cuda.empty_cache()
+    # self.profiler.log(
+    #     f"action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+    # )
+    # torch.cuda.empty_cache()
+    # current_memory = torch.cuda.memory_allocated()
+    # torch.cuda.reset_peak_memory_stats()
+    action_log_probs = memory_efficient_logprob(
+        action_logits,
+        inputs["input_ids"],
+        num_action,
+        shard_config=self.plugin.shard_config,
     )
+    # self.profiler.log(
+    #     f"me_action_log_probs_{self.global_step}: peak_memory: {(torch.cuda.max_memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+    # )
+    # if torch.allclose(action_log_probs, me_action_log_probs):
+    #     self.profiler.log("Memory efficient action log probs is same as normal action log probs")
+    # else:
+    #     self.profiler.log("Memory efficient action log probs is different from normal action log probs")
+    # torch.cuda.empty_cache()
     # breakpoint()
+    # current_memory = torch.cuda.memory_allocated()
+    # torch.cuda.empty_cache()
+    # self.profiler.log(
+    #     f"released by del outputs: {(torch.cuda.memory_allocated()-current_memory) / 1024 / 1024:.2f}MB"
+    # )
     if "reference_action_log_probs" in inputs:
         per_token_kl = (
             torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
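Note: with memory_efficient_logprob the log-softmax is only materialized over chunk_size positions at a time, so the transient activation per chunk is roughly vocab_size x chunk_size x 4 bytes instead of the full-sequence figure above. A rough estimate under the same assumptions (the real peak also depends on dtype, tensor-parallel sharding, and whatever else is resident on the device):

vocab_size = 152064
chunk_size = 2048   # default chunk length in memory_efficient_logprob
bytes_fp32 = 4

per_chunk_gb = vocab_size * chunk_size * bytes_fp32 / 1024**3
print(f"~{per_chunk_gb:.2f} GB per chunk")  # ~1.16 GB instead of tens of GB for the full sequence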
@@ -463,7 +491,7 @@ class GRPOConsumer(BaseConsumer):
 if need_update:
     # breakpoint()
     # torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
+    # torch.cuda.reset_peak_memory_stats()
     self.optimizer.step()
     self.optimizer.zero_grad()
     self.global_step += 1
@@ -448,7 +448,9 @@ class DummyProducer(BaseProducer):
 "input_ids": torch.cat(
     [
         torch.repeat_interleave(input_ids.unsqueeze(1), self.num_generations, dim=1).to(device),
-        torch.ones(
+        torch.randint(
+            0,
+            self.tokenizer.vocab_size,  # assuming vocab_size is available in tokenizer
             (input_ids.size(0), self.num_generations, num_new_tokens),
             dtype=input_ids.dtype,
         ).to(device),
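Note: replacing torch.ones with torch.randint makes the dummy producer emit varied token ids drawn from the tokenizer vocabulary instead of a constant, which gives the consumer more realistic inputs to profile. A small illustration of the difference (the vocabulary size and shapes below are placeholders, not values from the commit):

import torch

vocab_size = 151936  # placeholder; in the diff this comes from self.tokenizer.vocab_size
batch, num_generations, num_new_tokens = 2, 4, 8

constant_tokens = torch.ones((batch, num_generations, num_new_tokens), dtype=torch.long)
random_tokens = torch.randint(0, vocab_size, (batch, num_generations, num_new_tokens), dtype=torch.long)

print(constant_tokens.unique())                  # tensor([1]) -- every "generated" token is the same id
print(random_tokens.min(), random_tokens.max())  # ids spread across [0, vocab_size)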
@@ -146,6 +146,45 @@ def safe_append_to_jsonl_file(file_path, data):
             f.write(json_line + "\n")
 
 
+def memory_efficient_logprob(
+    logits: torch.Tensor,
+    inputs: torch.Tensor,
+    num_action: int,
+    chunk_size: int = 2048,
+    shard_config: Any = None,
+    vocab_size: int = None,
+) -> torch.Tensor:
+    """
+    Calculate action log probs in a memory-efficient way by processing in chunks.
+
+    Args:
+        logits (torch.Tensor): Output tensor of Actor.forward.logits.
+        inputs (torch.LongTensor): Input sequences.
+        num_action (int): Number of actions.
+        chunk_size (int, optional): Size of each chunk to process. Default is 2048.
+        shard_config: Shard configuration for distributed computation.
+        vocab_size (int, optional): Vocabulary size. Default is None.
+
+    Returns:
+        torch.Tensor: Action log probs.
+    """
+    action_log_probs = torch.zeros((logits.size(0), num_action), device=logits.device, dtype=logits.dtype)
+    context_length = logits.size(1) - num_action
+    for i in range(action_log_probs.size(0)):
+        # loop over each sample in the micro-batch
+        for start in range(context_length, logits.size(1), chunk_size):
+            end = min(start + chunk_size, logits.size(1))
+            # calculate log probs in chunks to save memory
+            log_probs = dist_log_prob(
+                inputs[i : i + 1, start - 1 : end],
+                logits[i : i + 1, start - 1 : end],
+                shard_config,
+                vocab_size,
+                logits.dtype,
+            )  # [1, chunk_size, 1]
+            log_probs = log_probs.squeeze(-1)
+            action_log_probs[i, start - context_length : end - context_length] += log_probs[0]
+    return action_log_probs
+
+
 class CustomProfiler:
     def __init__(self, name):
         self.name = name
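Note: the sketch below restates the idea behind memory_efficient_logprob in plain PyTorch, without the dist_log_prob/shard_config dependency, and checks that chunking the sequence dimension leaves the result unchanged. The helper name and the toy sizes are illustrative only and do not appear in the commit:

import torch

def chunked_action_logprob(logits: torch.Tensor, input_ids: torch.Tensor,
                           num_action: int, chunk_size: int = 2048) -> torch.Tensor:
    """Log-prob of each generated token, computing log_softmax chunk by chunk.

    logits: [B, S, V]; input_ids: [B, S]; the last `num_action` positions are the actions.
    Position t predicts token t+1, hence the shifted slicing below.
    """
    B, S, _ = logits.shape
    context_length = S - num_action
    out = torch.zeros((B, num_action), dtype=logits.dtype, device=logits.device)
    for start in range(context_length, S, chunk_size):
        end = min(start + chunk_size, S)
        # log_softmax only over this chunk of positions, never over the full sequence
        log_probs = torch.log_softmax(logits[:, start - 1 : end - 1], dim=-1)
        targets = input_ids[:, start:end].unsqueeze(-1)
        out[:, start - context_length : end - context_length] = log_probs.gather(-1, targets).squeeze(-1)
    return out

# sanity check against the unchunked computation on toy sizes
B, S, V, num_action = 2, 96, 50, 64
logits = torch.randn(B, S, V)
input_ids = torch.randint(0, V, (B, S))
full = torch.log_softmax(logits[:, S - num_action - 1 : -1], dim=-1).gather(
    -1, input_ids[:, S - num_action :].unsqueeze(-1)
).squeeze(-1)
assert torch.allclose(chunked_action_logprob(logits, input_ids, num_action, chunk_size=16), full, atol=1e-6)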
@@ -1,10 +1,8 @@
 # Re-import required libraries due to kernel reset
 import argparse
 from collections import defaultdict
-from datetime import datetime
 
 import matplotlib.cm as cm
-import matplotlib.dates as mdates
 import matplotlib.pyplot as plt
 
 # Argument parser for command line arguments
@@ -26,6 +24,10 @@ for file in files:
 actors = defaultdict(lambda: defaultdict(list))
 current_entries = {}
 
+# First, collect all timestamps to find the minimum
+all_timestamps = []
+parsed_lines = []
+
 for line in log_lines:
     if line.startswith("[Log]"):
         continue
@@ -34,13 +36,23 @@ for line in log_lines:
     actor = parts[1]
     action = parts[3]
     func_name = parts[4]
+    parsed_lines.append((timestamp, actor, action, func_name))
+    all_timestamps.append(timestamp)
+
+if not all_timestamps:
+    raise ValueError("No valid log entries found.")
+
+min_timestamp = min(all_timestamps)
+
+for timestamp, actor, action, func_name in parsed_lines:
+    rel_timestamp = timestamp - min_timestamp
     key = (actor, func_name)
     if action == "Enter":
-        current_entries[key] = timestamp
+        current_entries[key] = rel_timestamp
     elif action == "Exit":
         start_time = current_entries.pop(key, None)
         if start_time is not None:
-            actors[actor][func_name].append((start_time, timestamp))
+            actors[actor][func_name].append((start_time, rel_timestamp))
 
 # Plotting setup
 fig, ax = plt.subplots(figsize=(12, 6))
@@ -57,21 +69,23 @@ for idx, (actor, func_dict) in enumerate(actors.items()):
     actor_offsets[actor] = base_offset
     color = colors(idx)
     for j, (func, intervals) in enumerate(func_dict.items()):
+        print(actor, func, intervals)
         y_val = base_offset + j * function_spacing
         yticks.append(y_val)
         yticklabels.append(f"{actor}:{func}")
         for start, end in intervals:
+            if end - start < 100:
+                end = start + 100  # Ensure minimum length of 100ms
             ax.plot(
-                [datetime.fromtimestamp(start), datetime.fromtimestamp(end)],
+                [start, end],
                 [y_val, y_val],
                 color=color,
-                linewidth=4,
+                linewidth=2,
                 label=actor if j == 0 else "",
             )
     base_offset += len(func_dict) * function_spacing + 1
 
 # Formatting
-ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M:%S"))
 ax.set_yticks(yticks)
 ax.set_yticklabels(yticklabels)
 ax.set_xlabel("Time")
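Note: the plotting changes above switch the x-axis from wall-clock datetimes to times relative to the first log entry and drop the mdates formatter. A standalone sketch of the same horizontal-interval timeline, with made-up intervals standing in for the parsed profiler log:

import matplotlib.pyplot as plt

# made-up (actor, function) -> list of (start, end) intervals, in ms since the first event
actors = {
    "Producer-0": {"rollout": [(0, 400), (900, 1300)]},
    "Consumer-0": {"step": [(450, 850), (1350, 1800)]},
}

fig, ax = plt.subplots(figsize=(12, 4))
yticks, yticklabels = [], []
y = 0
for actor, func_dict in actors.items():
    for func, intervals in func_dict.items():
        yticks.append(y)
        yticklabels.append(f"{actor}:{func}")
        for start, end in intervals:
            if end - start < 100:
                end = start + 100  # enforce a minimum visible length of 100 ms
            ax.plot([start, end], [y, y], linewidth=2)
        y += 1
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels)
ax.set_xlabel("Time (ms since first event)")
plt.savefig("timeline_sketch.png")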
@@ -23,6 +23,17 @@ export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
 # python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_32768_GRPO_profile.png
 
 # 16K
-MAX_NEW_TOKENS=$((16384-512))
-python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.txt
-python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_16384_GRPO_profile.png
+# MAX_NEW_TOKENS=$((16384-512))
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 8 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_s_2_ptp_2_16384_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmb2_pp_2_ptp_2_16384_GRPO_profile.png
+
+# 8K
+# MAX_NEW_TOKENS=$((8192-512))
+# python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO-DUMMY-TEST -ibs 16 -tbs 16 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.txt
+# python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_ptp_2_8192_GRPO_profile.png
+
+
+# # 32K
+MAX_NEW_TOKENS=$((32768-512))
+python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 2 -b vllm -a GRPO-DUMMY-TEST -ibs 32 -tbs 32 -tMbs 2 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tmbs 2 -imbs 1 -p GRPO-Math-Profile -ei -5 -zero 1 -pp 2 -tp 2 -ptp 2 -mnt $MAX_NEW_TOKENS 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.txt
+python profile_grpo.py --visualization actor_timelines_ibs_64_tbs_32_tmbs_2_pp_2_tp_2_ptp_2_32768_GRPO_profile.png
@@ -318,8 +318,10 @@ if __name__ == "__main__":
 plugin_config={
     "tp_size": args.tensor_parallel_size,
     "pp_size": args.pipeline_parallel_size,
-    # "num_layers_per_stage": [12,12,2,2],
-    "num_layers_per_stage": [20, 8],
+    # "num_layers_per_stage": [18, 10],
+    # "num_layers_per_stage": [16, 11, 1],
+    # "num_layers_per_stage": [24, 4],
+    "num_layers_per_stage": [15, 13],
     "microbatch_size": max(
         1, args.train_microbatch_size // args.pipeline_parallel_size
     ),  # microbatch size should be set to train_microbatch_size // pp_size
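Note: each candidate num_layers_per_stage list partitions the model's transformer layers across the pipeline stages, so its length should match pp_size and its entries should sum to the model's layer count; the splits tried here all sum to 28, presumably the layer count of the profiled 7B model. A tiny sanity check one could run before launching (the layer count is an assumption, not stated in the commit):

num_layers = 28                  # assumed total transformer layer count of the profiled model
pp_size = 2                      # pipeline parallel size used in the profiling runs
num_layers_per_stage = [15, 13]  # the split enabled in this commit

assert len(num_layers_per_stage) == pp_size
assert sum(num_layers_per_stage) == num_layers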