[feat] add microbatch forwarding (#6251)
* add microbatch forwarding
* fix forward microbatch
* fix producer OOM
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* change project name
* fix temperature annealing
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* address conversation

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
commit 50153005b4 (parent 489f215ad9)

@@ -66,12 +66,7 @@ class BaseConsumer:
         cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

-        plugin_config = dict(
-            tp_size=1,
-            pp_size=1,
-            precision="bf16",
-            zero_stage=1,
-        )
+        plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
             plugin_config["microbatch_size"] = self.microbatch_size
         plugin_config.update(self.plugin_config)
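
The default plugin config above now runs the consumer with ZeRO stage 2 in bf16, and it is merged with any user-supplied plugin_config, so explicit settings still override the defaults. A minimal sketch of that merge behaviour, with literal values standing in for the consumer's `self.plugin_config` and `self.microbatch_size`:

```python
# Sketch only: `user_plugin_config` and the literal 4 stand in for
# self.plugin_config and self.microbatch_size on the consumer.
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
user_plugin_config = {"pp_size": 2, "zero_stage": 1}  # illustrative overrides

if user_plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in user_plugin_config:
    plugin_config["microbatch_size"] = 4

plugin_config.update(user_plugin_config)  # user settings win over the defaults
print(plugin_config)
# {'tp_size': 1, 'pp_size': 2, 'precision': 'bf16', 'zero_stage': 1, 'microbatch_size': 4}
```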
|
@@ -68,6 +68,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
+        self.training_config = training_config

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)

@@ -131,40 +132,16 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
+        forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))

         need_update = (step_idx + 1) % self.num_microbatches == 0
-        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
+        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
+        ctx = (
+            nullcontext()
+            if need_update or self.booster.plugin.zero_stage == 2
+            else self.booster.no_sync(self.policy_model, self.optimizer)
+        )
         with ctx:
-            policy_model_logits = self.policy_model(
-                input_ids=data["input_ids"],
-                attention_mask=data["attention_mask"],
-            )["logits"]
-            action_log_probs = calc_action_log_probs(
-                policy_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-
-            with torch.no_grad():
-                reference_model_logits = self.reference_model(
-                    input_ids=data["input_ids"],
-                    attention_mask=data["attention_mask"],
-                )["logits"]
-            reference_action_log_probs = calc_action_log_probs(
-                reference_model_logits / self.generate_config["temperature"],
-                data["input_ids"],
-                num_action,
-                self.plugin.shard_config,
-            )
-
-            per_token_kl = (
-                torch.exp(reference_action_log_probs - action_log_probs)
-                - (reference_action_log_probs - action_log_probs)
-                - 1
-            )
-            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
-
             reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
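
The key change above: gradient synchronization is no longer skipped between accumulation steps when ZeRO stage 2 is enabled, because stage 2 shards gradients across ranks and the reduction cannot be deferred (see the linked hybrid_parallel_plugin.py line). A minimal sketch of the same control flow, with `booster`, `policy_model`, and `optimizer` as stand-ins for the real objects rather than the repo's API:

```python
from contextlib import nullcontext

def grad_sync_context(booster, policy_model, optimizer, need_update: bool):
    """Illustrative sketch mirroring the hunk above, not the actual consumer code."""
    # On the step that applies the optimizer update, or whenever ZeRO stage 2
    # is active, gradients must be reduced, so nothing suppresses the sync.
    if need_update or booster.plugin.zero_stage == 2:
        return nullcontext()
    # Otherwise defer the gradient all-reduce until the accumulation boundary.
    return booster.no_sync(policy_model, optimizer)
```
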
@@ -177,6 +154,11 @@ class GRPOConsumer(BaseConsumer):

             group_reward = reward.view(-1, self.num_generations)
             reward_mean = group_reward.mean(dim=1)
+            # [batch_size x num_generations]
+            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [batch_size x num_generations]
+            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
             # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
             loss_mask = (
                 None
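
Advantages are now computed once for the whole rollout group before micro-batching: rewards are reshaped to [batch_size, num_generations], normalized by the per-group mean and standard deviation, and broadcast back to one scalar per sample. A toy example of the same arithmetic (2 prompts x 4 generations, made-up rewards):

```python
import torch

num_generations = 4
reward = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.5, 0.5, 1.0, 0.0])  # toy values, 2 prompts x 4 samples

group_reward = reward.view(-1, num_generations)                                    # [2, 4]
reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)   # [8]
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)     # [8]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)      # [8, 1]
```
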
@@ -185,35 +167,82 @@ class GRPOConsumer(BaseConsumer):
                     reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
                 ).repeat_interleave(self.num_generations, dim=0)
             )
-
-            # [batch_size x num_generations]
-            reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
-            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-            # [batch_size x num_generations]
-            advantages = (reward - reward_mean) / (reward_std + 1e-4)
-
-            loss, skip_update, _ = self.policy_loss_fn(
-                action_log_probs,
-                old_action_log_probs,
-                advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
-                per_token_kl,
-                action_mask,
-                loss_mask=loss_mask,
-            )
-
-            if not skip_update:
-                self.booster.backward(loss, self.optimizer)
-            loss = all_reduce_mean(loss, self.plugin)
+            mean_kl, mean_loss = [], []
+            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                input_ids_forward_micro_batch = data["input_ids"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                attention_mask_forward_micro_batch = data["attention_mask"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                action_mask_forward_micro_batch = action_mask[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                loss_mask_forward_micro_batch = (
+                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                    if loss_mask is not None
+                    else None
+                )
+                advantages_forward_micro_batch = advantages[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                policy_model_logits = self.policy_model(
+                    input_ids=input_ids_forward_micro_batch,
+                    attention_mask=attention_mask_forward_micro_batch,
+                ).logits
+                action_log_probs = calc_action_log_probs(
+                    policy_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )
+
+                with torch.no_grad():
+                    reference_model_logits = self.reference_model(
+                        input_ids=input_ids_forward_micro_batch,
+                        attention_mask=attention_mask_forward_micro_batch,
+                    ).logits
+                reference_action_log_probs = calc_action_log_probs(
+                    reference_model_logits / self.generate_config["temperature"],
+                    input_ids_forward_micro_batch,
+                    num_action,
+                    self.plugin.shard_config,
+                )
+
+                per_token_kl = (
+                    torch.exp(reference_action_log_probs - action_log_probs)
+                    - (reference_action_log_probs - action_log_probs)
+                    - 1
+                )
+                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                    action_mask_forward_micro_batch, dim=-1
+                )
+
+                loss, skip_update, _ = self.policy_loss_fn(
+                    action_log_probs,
+                    old_action_log_probs,
+                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                    per_token_kl,
+                    action_mask_forward_micro_batch,
+                    loss_mask=loss_mask_forward_micro_batch,
+                )
+
+                if not skip_update:
+                    self.booster.backward(loss, self.optimizer)
+                loss = all_reduce_mean(loss, self.plugin)
+                kl = all_reduce_mean(kl.mean(), self.plugin)
+                # Calculate accumulate value.
+                mean_kl.append(kl.data)
+                mean_loss.append(loss.data)
+
             reward = all_reduce_mean(reward.mean(), self.plugin)
-            kl = all_reduce_mean(kl.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
             response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            # Calculate accumulate value.
-            self.accum_loss.add_(loss.data)
+            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
             self.accum_reward.add_(reward.data)
-            self.accum_kl.add_(kl.data)
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
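
This loop is the heart of the microbatch forwarding feature: instead of one forward pass over the whole batch, the batch is sliced into chunks of forward_batch_size (taken from training_config["train_microbatch_size"]), each chunk is pushed through the policy and reference models, and the per-chunk loss and KL are collected before the accumulated values are logged. The per-token KL is the estimator exp(r) - r - 1 with r = ref_log_prob - policy_log_prob, which is non-negative. A self-contained sketch of the slicing-and-accumulation pattern with toy tensors (names mirror the diff; the model forwards and the policy loss are replaced by dummies):

```python
import torch

batch_size, seq_len, forward_batch_size = 8, 16, 2
data = {
    "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
}
action_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

mean_kl, mean_loss = [], []
for start in range(0, data["input_ids"].size(0), forward_batch_size):
    ids_mb = data["input_ids"][start : start + forward_batch_size]
    mask_mb = data["attention_mask"][start : start + forward_batch_size]
    action_mask_mb = action_mask[start : start + forward_batch_size]

    # Dummy log-probs standing in for the policy / reference forward passes.
    action_log_probs = torch.randn(forward_batch_size, seq_len)
    reference_action_log_probs = torch.randn(forward_batch_size, seq_len)

    # Same per-token KL estimator as in the diff: exp(r) - r - 1 >= 0.
    per_token_kl = (
        torch.exp(reference_action_log_probs - action_log_probs)
        - (reference_action_log_probs - action_log_probs)
        - 1
    )
    kl = torch.sum(per_token_kl * action_mask_mb, dim=-1) / torch.sum(action_mask_mb, dim=-1)

    loss = per_token_kl.mean()  # placeholder for self.policy_loss_fn(...)
    mean_kl.append(kl.mean())
    mean_loss.append(loss)

accum_loss = sum(mean_loss) / len(mean_loss)
accum_kl = sum(mean_kl) / len(mean_kl)
```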
|
@@ -34,6 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_microbatch_size: int,
+    train_minibatch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],

@@ -99,9 +100,13 @@ def launch_distributed(
             batch_size=train_batch_size,
             model_config=train_model_config,
             plugin_config=plugin_config,
-            microbatch_size=train_microbatch_size,
+            microbatch_size=train_minibatch_size,
             generate_config=generate_config_consumer,
-            training_config={"filter_range": [0.05, 9.0], "lr": 1e-6},
+            training_config={
+                "filter_range": [0.05, 9.0],
+                "lr": 1e-6,
+                "train_microbatch_size": train_microbatch_size,
+            },
             num_generations=num_generations,
         )
         procs.append(consumer)
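
The consumer's update micro-batch size is now the new train_minibatch_size, while the old train_microbatch_size travels in training_config and only controls the forward chunking shown earlier. A tiny sketch of the fallback the consumer applies when that key is absent (the literal 8 stands in for data["input_ids"].size(0)):

```python
# Illustrative only: mirrors forward_batch_size = self.training_config.get(...).
training_config = {"filter_range": [0.05, 9.0], "lr": 1e-6, "train_microbatch_size": 2}
full_batch_size = 8  # stands in for data["input_ids"].size(0)

forward_batch_size = training_config.get("train_microbatch_size", full_batch_size)
print(forward_batch_size)  # 2; would fall back to the full batch (8) if the key were missing
```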
|
@@ -100,6 +100,7 @@ class BaseProducer:
                 if i >= num_valid_microbatches:
                     break
                 outputs = self.rollout(**batch)
+
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config.temperature] * outputs["input_ids"].size(0)

@@ -116,16 +117,19 @@ class BaseProducer:
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
                     state_dict = ray_broadcast_tensor_dict(
                         None, self.num_producers, device=self.device, group_name="sync_model"
                     )
                     self.load_state_dict(state_dict)
+                    del state_dict
+                    torch.cuda.empty_cache()
                 # linear annealing for 1 episode, temperature from initial to 0.7
                 if episode <= 0:
                     ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
-                    self.model.generate_config.temperature = (
-                        ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7
-                    )
+                    self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.7


 @ray.remote
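
Two fixes land in this hunk: the broadcast state_dict is deleted and the CUDA cache emptied right after loading, addressing the producer OOM from the commit message, and the annealing direction is corrected so the sampling temperature starts at the configured value and decays linearly toward 0.7 over the first episode. A small worked example of the corrected schedule (dataloader length and starting temperature are illustrative):

```python
initial_temperature = 0.9  # plays the role of generate_config["temperature"]
num_batches = 5            # stands in for len(self.dataloader)

for i in range(num_batches):
    ratio = 1 - (num_batches - i) / num_batches
    temperature = (1 - ratio) * initial_temperature + ratio * 0.7
    print(i, round(temperature, 3))
# 0 0.9
# 1 0.86
# 2 0.82
# 3 0.78
# 4 0.74
```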
|
@@ -10,18 +10,30 @@ if __name__ == "__main__":
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
+    parser.add_argument("-g", "--num-generations", type=int, default=8)
     parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
     parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
     parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
-    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
+    parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
+    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()

+    assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
+    assert (
+        args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
+        and args.train_microbatch_size > 0
+    ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
+
     ray.init(address="local", namespace="ray-example")

     inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model)
+    train_model_config = dict(
+        path=args.model,
+        # use_flash_attention_2=True,
+        # use_cache=False
+    )
     generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)

     if args.backend == "transformers":
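
The new -tMbs/--train-minibatch-size flag is distinct from -tmbs/--train-microbatch-size: the mini-batch size sets how many prompts the consumer steps on at once, while the micro-batch size caps how many sequences go through each forward pass (this reading is inferred from how train_microbatch_size feeds forward_batch_size in the consumer hunk above). The assertion ties them together; with the defaults (train_minibatch_size=1, num_generations=8, train_microbatch_size=2) each mini-batch of 1 * 8 = 8 sequences is forwarded in 4 chunks of 2. A tiny check of that relation:

```python
train_minibatch_size, num_generations, train_microbatch_size = 1, 8, 2  # the defaults above

sequences_per_minibatch = train_minibatch_size * num_generations
assert train_minibatch_size > 0
assert sequences_per_minibatch >= train_microbatch_size and train_microbatch_size > 0

forward_passes = -(-sequences_per_minibatch // train_microbatch_size)  # ceil division
print(sequences_per_minibatch, forward_passes)  # 8 4
```
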
@@ -31,13 +43,6 @@ if __name__ == "__main__":
                 torch_dtype=torch.bfloat16,
             )
         )
-        train_model_config.update(
-            dict(
-                use_flash_attention_2=True,
-                torch_dtype=torch.bfloat16,
-                use_cache=False,
-            )
-        )
         generate_config.update(
             dict(
                 max_length=1024 + 512,

@@ -78,15 +83,17 @@ if __name__ == "__main__":
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
+        train_minibatch_size=args.train_minibatch_size,
         train_microbatch_size=args.train_microbatch_size,
         dataset_config={"path": args.dataset, "max_length": 300},
         dataloaders_config={},
         inference_model_config=inference_model_config,
         generate_config=generate_config,
+        num_generations=args.num_generations,
         train_model_config=train_model_config,
         plugin_config={},
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29503,
+        master_port=29505,
         core_algo=args.algo,
     )
|