[feat] add microbatch forwarding (#6251)

* add microbatch forwarding

* fix forward microbatch

* fix producer OOM

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change project name

* fix temperature annealing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address conversation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
YeAnbang 2025-03-28 10:24:58 +08:00 committed by GitHub
parent 489f215ad9
commit 50153005b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 112 additions and 72 deletions

View File

@ -66,12 +66,7 @@ class BaseConsumer:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
plugin_config = dict(
tp_size=1,
pp_size=1,
precision="bf16",
zero_stage=1,
)
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
plugin_config.update(self.plugin_config)

View File

@ -68,6 +68,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0
self.generate_config = generate_config
self.training_config = training_config
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@ -131,40 +132,16 @@ class GRPOConsumer(BaseConsumer):
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
need_update = (step_idx + 1) % self.num_microbatches == 0
ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
ctx = (
nullcontext()
if need_update or self.booster.plugin.zero_stage == 2
else self.booster.no_sync(self.policy_model, self.optimizer)
)
with ctx:
policy_model_logits = self.policy_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(
policy_model_logits / self.generate_config["temperature"],
data["input_ids"],
num_action,
self.plugin.shard_config,
)
with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
data["input_ids"],
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
@ -177,6 +154,11 @@ class GRPOConsumer(BaseConsumer):
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [batch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
loss_mask = (
None
@ -185,35 +167,82 @@ class GRPOConsumer(BaseConsumer):
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
).repeat_interleave(self.num_generations, dim=0)
)
mean_kl, mean_loss = [], []
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
action_mask_forward_micro_batch = action_mask[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
loss_mask_forward_micro_batch = (
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
if loss_mask is not None
else None
)
advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
action_log_probs = calc_action_log_probs(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
# [batch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = (reward - reward_mean) / (reward_std + 1e-4)
with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask,
loss_mask=loss_mask,
action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch,
)
if not skip_update:
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin)
reward = all_reduce_mean(reward.mean(), self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin)
# Calculate accumulate value.
mean_kl.append(kl.data)
mean_loss.append(loss.data)
reward = all_reduce_mean(reward.mean(), self.plugin)
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
# Calculate accumulate value.
self.accum_loss.add_(loss.data)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_kl.add_(kl.data)
self.accum_format_reward.add_(format_reward.data)
self.accum_acc_reward.add_(acc_reward.data)
self.accum_advantages.add_(advantages.data)

View File

@ -34,6 +34,7 @@ def launch_distributed(
inference_microbatch_size: int,
train_batch_size: int,
train_microbatch_size: int,
train_minibatch_size: int,
dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
@ -99,9 +100,13 @@ def launch_distributed(
batch_size=train_batch_size,
model_config=train_model_config,
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
microbatch_size=train_minibatch_size,
generate_config=generate_config_consumer,
training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
training_config={
"filter_range": [0.05, 9.0],
"lr": 1e-6,
"train_microbatch_size": train_microbatch_size,
},
num_generations=num_generations,
)
procs.append(consumer)

View File

@ -100,6 +100,7 @@ class BaseProducer:
if i >= num_valid_microbatches:
break
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor(
[self.model.generate_config.temperature] * outputs["input_ids"].size(0)
@ -116,16 +117,19 @@ class BaseProducer:
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
self.load_state_dict(state_dict)
del state_dict
torch.cuda.empty_cache()
# linear annealing for 1 episode, temperature from initial to 0.7
if episode <= 0:
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
self.model.generate_config.temperature = (
ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
)
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.7
@ray.remote

View File

@ -10,18 +10,30 @@ if __name__ == "__main__":
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-g", "--num-generations", type=int, default=8)
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
parser.add_argument("-b", "--backend", type=str, default="transformers")
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
args = parser.parse_args()
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
assert (
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model)
train_model_config = dict(
path=args.model,
# use_flash_attention_2=True,
# use_cache=False
)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
if args.backend == "transformers":
@ -31,13 +43,6 @@ if __name__ == "__main__":
torch_dtype=torch.bfloat16,
)
)
train_model_config.update(
dict(
use_flash_attention_2=True,
torch_dtype=torch.bfloat16,
use_cache=False,
)
)
generate_config.update(
dict(
max_length=1024 + 512,
@ -78,15 +83,17 @@ if __name__ == "__main__":
inference_batch_size=args.inference_batch_size,
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_minibatch_size=args.train_minibatch_size,
train_microbatch_size=args.train_microbatch_size,
dataset_config={"path": args.dataset, "max_length": 300},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
plugin_config={},
inference_backend=args.backend,
master_addr="localhost",
master_port=29503,
master_port=29505,
core_algo=args.algo,
)