Mirror of https://github.com/hpcaitech/ColossalAI.git
[feat] Support prompt level dynamic (#6300)
* adjust to dynamic prompt bs
* remove debug
* update pad seq (#6303)
  Co-authored-by: Tong Li <tong.li35271158@gmail.com>
* adjust to dynamic prompt bs
* remove debug
* fix dp issue
* fix
* fix default settings

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Commit aca547623f (parent b920af427b)
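In short, the consumer now accounts for work at the prompt level rather than the sample level: an optimizer step is triggered once enough effective prompts (prompts with at least one usable generation) have been collected, and any surplus prompts are carried over to the next step instead of being dropped. A minimal standalone sketch of that idea with toy values (the tensor contents and batch sizes are made up; only the variable names mirror the diff below):

# Minimal sketch of prompt-level dynamic batching (toy values, not the real trainer).
import torch

num_generations = 4          # responses sampled per prompt
batch_size, dp_size = 2, 2   # target prompts per DP rank, number of DP ranks

# Per-sample keep/drop mask for one minibatch of 3 prompts x 4 generations.
loss_mask = torch.tensor(
    [1, 1, 1, 1,   # prompt 0: all generations kept
     0, 0, 0, 0,   # prompt 1: filtered out entirely
     1, 0, 1, 1],  # prompt 2: partially kept, still an effective prompt
    dtype=torch.bool,
)

# A prompt is "effective" if any of its generations survives filtering.
prompt_level_mask = loss_mask.view(-1, num_generations)
effective_prompts_mask = prompt_level_mask.any(dim=1)
effective_prompt_count = int(effective_prompts_mask.sum())

# Trigger an optimizer step once enough prompts (not samples) are collected.
need_update = effective_prompt_count >= batch_size * dp_size
print(effective_prompt_count, need_update)  # 2 False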
@@ -107,9 +107,14 @@ class BaseConsumer:
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
         for episode in range(self.num_episodes):
-            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
+            with tqdm(
+                range(self.num_update_per_episode),
+                desc=f"Episode {episode} with rollout step(s)",
+                disable=self.rank != 0,
+            ) as pbar:
                 for step in pbar:
                     i = 0
+                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -127,15 +132,15 @@ class BaseConsumer:
                             ]
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss, num_excessive_prompts = self.step(i, pbar, **batch)
-                            self.buffer = (
-                                self.buffer[
-                                    (self.dp_rank + 1) * self.minibatch_size
-                                    - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
-                                ]
-                                + self.buffer[self.dp_size * self.minibatch_size :]
-                            )
+                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
+
+                            if excessive_prompts_idx is not None:
+                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
+                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
+                            else:
+                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
+                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
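The consumer-side half of the carry-over is easiest to see with a toy buffer. A rough sketch under assumed sizes (`dp_size`, `minibatch_size`, and the buffer contents are hypothetical; `excessive_prompts_idx` stands in for the value returned by `step()`):

# Toy illustration of the new buffer handling: prompts flagged as excessive
# by step() are re-queued for the next iteration instead of being discarded.
dp_size, minibatch_size = 2, 2
buffer = ["p0", "p1", "p2", "p3", "p4", "p5"]  # hypothetical queued rollouts
excessive_prompts_idx = [3]                    # hypothetical return value of step()

if excessive_prompts_idx is not None:
    excessive_prompts = [buffer[idx] for idx in excessive_prompts_idx]
    buffer = excessive_prompts + buffer[dp_size * minibatch_size:]
else:
    buffer = buffer[dp_size * minibatch_size:]

print(buffer)  # ['p3', 'p4', 'p5']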
@@ -149,29 +154,31 @@ class BaseConsumer:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if self.pp_size > 1:
-                            print(
-                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
-                            )
-                        else:
-                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                        torch.cuda.empty_cache()
-                        state_dict = self.state_dict()
-                        if self.pp_size > 1:
-                            if self.tp_rank == 0 and self.dp_rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict,
-                                    src=self.num_producers,
-                                    device=self.device,
-                                    group_name=f"sync_model_{self.pp_rank}",
-                                )
-                        else:
-                            if self.rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                )
-                        del state_dict
-                        torch.cuda.empty_cache()
+                        if allow_sync_model:
+                            if self.pp_size > 1:
+                                print(
+                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                                )
+                            else:
+                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                            torch.cuda.empty_cache()
+                            state_dict = self.state_dict()
+                            if self.pp_size > 1:
+                                if self.tp_rank == 0 and self.dp_rank == 0:
+                                    ray_broadcast_tensor_dict(
+                                        state_dict,
+                                        src=self.num_producers,
+                                        device=self.device,
+                                        group_name=f"sync_model_{self.pp_rank}",
+                                    )
+                            else:
+                                if self.rank == 0:
+                                    ray_broadcast_tensor_dict(
+                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                    )
+                            del state_dict
+                            torch.cuda.empty_cache()
+                            allow_sync_model = False
 
 
 @ray.remote
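The `allow_sync_model` flag introduced above simply defers the weight broadcast until an optimizer step has actually happened during the rollout step. A schematic sketch of that gating (the helper names `rollout_step`, `step_fn`, and `sync_fn` are illustrative only, not part of the codebase):

# Schematic gating: broadcast weights to the producers only if at least one
# optimizer step happened during this rollout step (illustrative helpers).
def rollout_step(recv_batches, step_fn, sync_fn):
    allow_sync_model = False
    for batch in recv_batches:
        loss = step_fn(batch)  # returns None while prompts are still accumulating
        if loss is not None:
            allow_sync_model = True
    if allow_sync_model:
        sync_fn()              # stands in for the ray_broadcast_tensor_dict() calls above


rollout_step(
    recv_batches=[{"minibatch": 0}, {"minibatch": 1}],
    step_fn=lambda batch: 0.42,          # pretend every call performed an update
    sync_fn=lambda: print("model synced"),
)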
@@ -1,4 +1,3 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Optional
 
@@ -10,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -42,13 +41,6 @@ class GRPOConsumer(BaseConsumer):
         save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
-        if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != minibatch_size:
-                warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
-                    UserWarning,
-                )
-                minibatch_size = batch_size
         if (
             plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in plugin_config
@@ -90,6 +82,7 @@ class GRPOConsumer(BaseConsumer):
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.effective_prompt_count = 0
         self.total_sample_count = 0
 
         self.policy_loss_fn = PolicyLoss(
@@ -213,70 +206,66 @@ class GRPOConsumer(BaseConsumer):
         group_ans_acc = (
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
+        # [minibatch_size x num_of_generation]
         loss_mask = (
             torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
             else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
         )
 
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
 
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None
+
+        mean_kl, mean_loss = [], []
+
+        if self.grpo_config.get("dynamic_batching", True):
+            need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
+            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
+
+            if excessive_prompts > 0:
+                excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                # Only count excessive prompts if they are greater than 1 per rank.
+                # TODO: customize excessive prompts calculation.
+                if excessive_prompts_per_rank != 0:
+                    # Mask excessive prompts to False
+                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                    if excessive_prompts_per_rank <= len(true_indices):
+                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                    else:
+                        excessive_prompts_idx = true_indices
+                    effective_prompts_mask[excessive_prompts_idx] = False
+
+                    for mask_idx in range(len(effective_prompts_mask)):
+                        if effective_prompts_mask[mask_idx] == False:
+                            # Update loss mask.
+                            loss_mask[mask_idx] = False
+        else:
+            # If dynamic batching is disabled, we need to use all samples for training.
+            need_update = (step_idx + 1) % self.num_microbatches == 0
+
         effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
         total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
         self.total_sample_count += total_samples.item()
 
-        mean_kl, mean_loss = [], []
-
-        if self.grpo_config.get("dynamic_batching", True):
-            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-            num_excessive_samples = (
-                int(
-                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                    / self.num_generations
-                    / self.dp_size
-                )
-                * self.num_generations
-            )
-            if num_excessive_samples > 0:
-                data = {
-                    k: (
-                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                        if k
-                        in [
-                            "input_ids",
-                            "attention_mask",
-                            "action_log_probs",
-                            "action_mask",
-                            "response_idx",
-                            "gt_answer",
-                        ]
-                        else v
-                    )
-                    for k, v in data.items()
-                }
-                action_mask = action_mask[
-                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
-                ]
-                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-            else:
-                num_excessive_samples = 0
-        else:
-            # If dynamic batching is disabled, we need to use all samples for training.
-            need_update = (step_idx + 1) % self.num_microbatches == 0
-            num_excessive_samples = 0
 
         pbar.set_postfix(
             {
-                "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Global Step": self.global_step,
+                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
+                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )
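How each rank picks which surplus prompts to defer can be illustrated with a small tensor. A rough sketch with assumed values (only the variable names follow the diff above; `squeeze(-1)` is used here instead of a bare `squeeze()` to keep the toy case 1-D):

# Sketch of deferring surplus prompts: the last excessive_prompts_per_rank
# effective prompts are flagged and masked out of this step's batch.
import torch

effective_prompts_mask = torch.tensor([True, False, True, True])
excessive_prompts_per_rank = 1

true_indices = torch.nonzero(effective_prompts_mask).squeeze(-1)         # tensor([0, 2, 3])
if excessive_prompts_per_rank <= len(true_indices):
    excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]   # tensor([3])
else:
    excessive_prompts_idx = true_indices
effective_prompts_mask[excessive_prompts_idx] = False                    # defer prompt 3

print(excessive_prompts_idx.tolist(), effective_prompts_mask.tolist())
# [3] [True, False, True, False]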
@@ -375,7 +364,7 @@ class GRPOConsumer(BaseConsumer):
                     kl.append(appox_kl.mean())
                 else:
                     per_token_kl = 0.0
-                    kl.append(0.0)
+                    kl.append(torch.tensor(0.0))
 
                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
@@ -479,6 +468,7 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             loss_scalar = self.accum_loss.item()
@@ -495,6 +485,7 @@ class GRPOConsumer(BaseConsumer):
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -520,9 +511,15 @@ class GRPOConsumer(BaseConsumer):
                 self.accum_advantages.zero_()
                 self.accum_response_length.zero_()
                 self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+
+            if excessive_prompts_idx is not None:
+                # All gather excessive prompts index across DP ranks.
+                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
+                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
+
+            return loss_scalar, excessive_prompts_idx
         else:
-            return None, num_excessive_samples // self.num_generations
+            return None, excessive_prompts_idx
 
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
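Before being returned, the rank-local indices are shifted into the shared buffer's coordinate system and only then gathered across DP ranks. A toy illustration with assumed `dp_rank` and `minibatch_size`:

# Toy conversion of rank-local prompt indices into positions in the shared
# consumer buffer before the all-gather (values are hypothetical).
dp_rank, minibatch_size = 1, 4
excessive_prompts_idx = [2, 3]  # local positions within this rank's slice
global_idx = [idx + dp_rank * minibatch_size for idx in excessive_prompts_idx]
print(global_idx)  # [6, 7]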
@@ -144,3 +144,29 @@ def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
     else:
         dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
     return tensor
+
+
+def all_gather_tensors(local_tensor_list: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
+    """
+    Gathers tensors from all processes and concatenates them along the first dimension.
+
+    Args:
+        tensor (torch.Tensor): The input tensor to be gathered.
+
+    Returns:
+        torch.Tensor: The gathered tensor.
+    """
+    # Gather tensors across DP group
+    if plugin is not None:
+        all_tensor_lists = [None] * plugin.dp_size
+        dist.all_gather_object(all_tensor_lists, local_tensor_list, group=plugin.dp_group)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    else:
+        all_tensor_lists = [None] * dist.get_world_size()
+        dist.all_gather_object(all_tensor_lists, local_tensor_list)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    return gathered_tensor_list
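A possible calling pattern for the new helper, shown as a single-process sanity check (the gloo backend, world size 1, and the import path are assumptions; in training the gather runs over the DP group via the booster plugin):

# Hypothetical single-process usage of the new helper, to show the calling convention.
import os
import torch
import torch.distributed as dist
from coati.distributed.utils import all_gather_tensors  # assumed import path

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

local = [torch.tensor(3), torch.tensor(7)]         # e.g. excessive prompt indices on this rank
gathered = all_gather_tensors(local, plugin=None)  # falls back to the default process group
print(gathered)  # every rank ends up with the concatenated list from all ranks

dist.destroy_process_group()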
@@ -9,7 +9,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
-    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+    parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
     parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
 
     # Distributed training parameters
@@ -20,7 +20,7 @@ if __name__ == "__main__":
         "-ibs",
         "--inference-batch-size",
         type=int,
-        default=None,
+        default=64,
         help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
@@ -41,7 +41,7 @@ if __name__ == "__main__":
         "-tMbs",
         "--train-minibatch-size",
         type=int,
-        default=None,
+        default=8,
         help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
     )
     parser.add_argument(
@@ -58,7 +58,7 @@ if __name__ == "__main__":
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
     parser.add_argument(
-        "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
+        "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
     )
 
     # Sampling parameters
@@ -223,7 +223,7 @@ if __name__ == "__main__":
             "zero_stage": 2,
         },  # for zero
         # plugin_config={
-        #     "tp_size": 1,
+        #     "tp_size": 2,
         #     "pp_size": 2,
         #     "microbatch_size": max(
         #         1, args.train_microbatch_size // 2