Repository: mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 15:01:43 +00:00)

Commit a40d82f629 (parent 061d8cb3b6): address conversation
@@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
         use_wandb=True,
         generate_config=None,
         training_config={},
+        project_name=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_count = 0
         self.generate_config = generate_config
         self.training_config = training_config
+        self.project_name = project_name

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -111,7 +113,7 @@ class GRPOConsumer(BaseConsumer):
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)

         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
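Note: this hunk stops hard-coding the wandb project ("GRPO-V1-PP") and instead uses the project_name threaded in from the launcher. A minimal sketch of the pattern, with an illustrative Consumer class rather than the repository's exact signature:

# Sketch: store a configurable project name at construction time and use it when
# initializing wandb. The Consumer class and its constructor are illustrative.
import wandb


class Consumer:
    def __init__(self, generate_config, batch_size, dp_size, project_name="GRPO"):
        self.generate_config = generate_config
        self.batch_size = batch_size
        self.dp_size = dp_size
        self.project_name = project_name  # stored here, consumed later in setup()

    def setup(self):
        # Run name encodes the sampling configuration, as in the hunk above.
        name = (
            f"{self.generate_config['backend']}"
            f"_bs_{self.batch_size * self.dp_size}"
            f"_temp_{self.generate_config['temperature']:.01f}"
            f"_top_p_{self.generate_config['top_p']:.02f}"
        )
        return wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)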
@@ -239,6 +241,8 @@ class GRPOConsumer(BaseConsumer):
                     "source": self.rank,
                 }

+                kl = []
+
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
                     action_log_probs = calc_action_log_probs(
@@ -252,6 +256,10 @@ class GRPOConsumer(BaseConsumer):
                         - (inputs["reference_action_log_probs"] - action_log_probs)
                         - 1
                     )
+                    appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
+                        inputs["action_mask"], dim=-1
+                    )
+                    kl.append(appox_kl.mean())
                     loss, skip_update, _ = self.policy_loss_fn(
                         action_log_probs,
                         action_log_probs,
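Note: the added lines compute a per-token approximate KL between the policy and the frozen reference model with the estimator exp(d) - d - 1, where d = reference_action_log_probs - action_log_probs, average it over the unmasked action tokens, and append one scalar per micro-batch to the kl list for later aggregation. A standalone sketch of that masked estimator (tensor names mirror the diff; the helper function itself is illustrative):

# Masked per-sequence approximate KL, as used inside _criterion above.
import torch


def masked_approx_kl(action_log_probs, reference_action_log_probs, action_mask):
    # action_log_probs, reference_action_log_probs: (B, T) log-probs of the generated tokens
    # action_mask: (B, T), 1 for generated tokens, 0 for prompt/padding
    delta = reference_action_log_probs - action_log_probs
    per_token_kl = torch.exp(delta) - delta - 1  # non-negative, low-variance KL estimate
    # Mean over valid tokens of each sequence, then over the batch.
    per_seq_kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
    return per_seq_kl.mean()


if __name__ == "__main__":
    logp = torch.randn(2, 5)
    ref_logp = logp + 0.1 * torch.randn(2, 5)
    mask = torch.ones(2, 5)
    print(masked_approx_kl(logp, ref_logp, mask))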
@@ -273,26 +281,11 @@ class GRPOConsumer(BaseConsumer):
                     loss = policy_model_outputs["loss"]

                 if self.booster.plugin.stage_manager.is_last_stage():
-                    # calculate kl, as we cannot do this inside callback, kl needs be calculate again
-                    action_logits = policy_model_outputs["outputs"]["logits"]
-                    action_log_probs = calc_action_log_probs(
-                        action_logits / self.generate_config["temperature"],
-                        input_ids_forward_micro_batch,
-                        num_action,
-                        self.plugin.shard_config,
-                    )
-                    per_token_kl = (
-                        torch.exp(reference_action_log_probs - action_log_probs)
-                        - (reference_action_log_probs - action_log_probs)
-                        - 1
-                    )
-                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
-                        action_mask_forward_micro_batch, dim=-1
-                    )
-                    kl = all_reduce_mean(kl.mean(), self.plugin)
+                    if len(kl) > 0:
+                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                        mean_kl.append(kl)
                     loss = all_reduce_mean(loss, self.plugin)
                     mean_loss.append(loss.data)
-                    mean_kl.append(kl)
                 else:

                     policy_model_logits = self.policy_model(
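Note: instead of recomputing the KL on the last pipeline stage from the micro-batch logits, the values collected inside the _criterion callback are stacked, averaged, and all-reduced across data-parallel ranks; the guard on len(kl) covers stages where the criterion never ran. A minimal sketch of this accumulate-in-closure pattern (the plain loop below stands in for the pipeline schedule and is not the booster/plugin API):

# Sketch: a list captured by the criterion closure collects one scalar per
# micro-batch; the metric is aggregated after the schedule has run.
import torch


def run_step(micro_batches):
    kl = []  # captured by _criterion; one entry per micro-batch

    def _criterion(outputs, inputs):
        loss = outputs["loss"]
        kl.append(inputs["approx_kl"].mean())  # side-channel metric, not part of the loss
        return loss

    total_loss = torch.zeros(())
    for mb in micro_batches:  # stand-in for the pipeline schedule
        total_loss = total_loss + _criterion({"loss": mb["loss"]}, mb)

    # Aggregate the metric only if the criterion actually ran on this stage.
    mean_kl = torch.mean(torch.stack(kl)) if len(kl) > 0 else None
    return total_loss / len(micro_batches), mean_kl


if __name__ == "__main__":
    mbs = [{"loss": torch.tensor(1.0), "approx_kl": torch.tensor([0.02, 0.03])} for _ in range(4)]
    print(run_step(mbs))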
@@ -47,6 +47,7 @@ def launch_distributed(
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
+    project_name: Optional[str] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@ def launch_distributed(
                 "train_microbatch_size": train_microbatch_size,
             },
             num_generations=num_generations,
+            project_name=project_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
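Note: launch_distributed only forwards the new project_name keyword to the consumer actors it spawns. A hedged sketch of that passthrough with Ray (the Consumer actor and its argument list are illustrative, not the repository's exact classes):

# Sketch: pass project_name through to a Ray actor and wait for setup.
import ray


@ray.remote
class Consumer:
    def __init__(self, num_generations, project_name=None, **training_config):
        self.num_generations = num_generations
        self.project_name = project_name
        self.training_config = training_config

    def setup(self):
        # In the real consumer this is where wandb.init(project=self.project_name, ...) runs.
        return self.project_name


def launch(project_name=None):
    procs = []
    consumer = Consumer.remote(num_generations=8, project_name=project_name, train_microbatch_size=2)
    procs.append(consumer)
    ray.get([p.setup.remote() for p in procs])
    return procs


if __name__ == "__main__":
    ray.init()
    launch(project_name="GRPO")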
@@ -11,32 +11,41 @@ if __name__ == "__main__":
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
     parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument(
-        "-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step."
+        "-ibs",
+        "--inference-batch-size",
+        type=int,
+        default=64,
+        help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
         "-imbs",
         "--inference-microbatch-size",
         type=int,
         default=8,
-        help="Number of prompts to send from the producer to the consumer.",
+        help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
     )
     parser.add_argument(
-        "-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model."
+        "-tbs",
+        "--train-batch-size",
+        type=int,
+        default=32,
+        help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
     )
     parser.add_argument(
         "-tMbs",
         "--train-minibatch-size",
         type=int,
         default=1,
-        help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
+        help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
     )
     parser.add_argument(
         "-tmbs",
         "--train-microbatch-size",
         type=int,
         default=2,
-        help="Number of samples per device. PP micro batchsize when PP is activated.",
+        help="Effective batch size per dp group for forwarding and backwarding. Please select based on the available memory.",
     )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
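Note: the rewritten help strings encode a few relationships between the batch-size flags: ibs must be divisible by tbs and the inference weights are synced every ibs/tbs policy steps; each optimizer step covers tbs unique prompts per dp group, i.e. tbs * g * dp_size samples in total; and tMbs * g must be at least tmbs. A small sanity-check sketch of my reading of those constraints (not a function that exists in the repository):

# Sketch: relationships implied by the CLI help strings above.
def check_batch_sizes(ibs, tbs, tMbs, tmbs, g, dp_size):
    # ibs must be divisible by tbs; inference weights are synced every ibs / tbs policy steps.
    assert ibs % tbs == 0, "inference-batch-size (ibs) should be divisible by train-batch-size (tbs)"
    sync_every = ibs // tbs

    # Per optimizer step, each dp group accumulates gradients over tbs unique prompts,
    # i.e. tbs * g samples; globally that is tbs * g * dp_size samples.
    samples_per_step = tbs * g * dp_size

    # Each training mini-batch (tMbs * g samples per dp group) must cover at least one
    # forward/backward micro-batch of size tmbs.
    assert tMbs * g >= tmbs, "train-minibatch-size * num-generations must be >= train-microbatch-size"

    print(f"sync weights every {sync_every} steps; {samples_per_step} samples per optimizer step")


if __name__ == "__main__":
    # Defaults from the argument parser above, with an assumed dp_size of 2.
    check_batch_sizes(ibs=64, tbs=32, tMbs=1, tmbs=2, g=8, dp_size=2)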
@@ -119,6 +128,7 @@ if __name__ == "__main__":
         }, # for pp
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29505,
+        master_port=29506,
         core_algo=args.algo,
+        project_name=args.project,
     )