mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-17 11:19:57 +00:00)

fix vllm

This commit is contained in:
parent 7795d4c50d
commit 7ee4452f8c
@@ -1,8 +1,11 @@
+import json
+import os
 from contextlib import nullcontext
 from typing import Optional
 
 import ray
 import torch
+import torch.distributed as dist
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
@@ -33,6 +36,8 @@ class GRPOConsumer(BaseConsumer):
         microbatch_size=1,
         num_generations=4,
         use_wandb=True,
+        generator_config=None,
+        filter_range=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +74,9 @@ class GRPOConsumer(BaseConsumer):
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
+        self.filter_range = filter_range
+        if self.filter_range is not None:
+            assert len(self.filter_range) == 2, "Filter range should have 2 values."
 
         # Initialize verifiable reward.
         response_format_tags = {
@@ -84,7 +92,11 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
+            if "repetition_penalty" in generator_config:
+                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}"
+            else:
+                name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}"
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
 
     def setup(self):
         super().setup()
@@ -121,7 +133,7 @@ class GRPOConsumer(BaseConsumer):
                 attention_mask=data["attention_mask"],
             )["logits"]
             action_log_probs = calc_action_log_probs(
-                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
             )
 
             with torch.no_grad():
@@ -130,7 +142,7 @@ class GRPOConsumer(BaseConsumer):
                     attention_mask=data["attention_mask"],
                 )["logits"]
                 reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                    reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config
                 )
 
             per_token_kl = (
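
Note: the two log-prob hunks above divide the policy and reference logits by the rollout temperature before calc_action_log_probs, presumably so the trainer-side log-probabilities describe the same tempered distribution the inference backend actually sampled from. A minimal standalone sketch of that idea with toy tensors and plain PyTorch (here temperature stands in for generator_config["temperature"]; this is not the commit's code):

    import torch
    import torch.nn.functional as F

    temperature = 0.7                      # rollout sampling temperature
    logits = torch.randn(2, 5, 32)         # [batch, seq_len, vocab_size]
    tokens = torch.randint(0, 32, (2, 5))  # token ids sampled at rollout time

    # Log-probs of the tempered distribution the sampler drew from: softmax(logits / T).
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    per_token_logp = log_probs.gather(dim=-1, index=tokens.unsqueeze(-1)).squeeze(-1)
    print(per_token_logp.shape)            # torch.Size([2, 5])
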
@@ -149,7 +161,14 @@ class GRPOConsumer(BaseConsumer):
             acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
 
             # [batch_size, num_generations]
+            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
+            loss_mask = (
+                None
+                if self.filter_range is None
+                else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1])
+            )
             group_reward = reward.view(-1, self.num_generations)
+            reward_mean = group_reward.mean(dim=1)
 
             # [batch_size x num_generations]
             reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
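
Note: the new loss_mask drops samples whose reward falls outside filter_range; per the comment in the hunk, this filters out cases where the reward is too high (all samples get full score) or too low (all samples get zero). A toy sketch of the masking rule with illustrative values (not the commit's code):

    import torch

    filter_range = [0.05, 9.0]                   # same bounds passed in launch_distributed
    reward = torch.tensor([0.0, 0.3, 9.5, 4.2])  # one scalar reward per sampled response

    # Keep only samples whose reward lies strictly inside (low, high).
    loss_mask = torch.logical_and(reward > filter_range[0], reward < filter_range[1])
    print(loss_mask)  # tensor([False,  True, False,  True])
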
@@ -164,6 +183,7 @@ class GRPOConsumer(BaseConsumer):
                 advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
                 per_token_kl,
                 action_mask,
+                loss_mask=loss_mask,
             )
 
             if not skip_update:
@@ -232,3 +252,125 @@ class GRPOConsumer(BaseConsumer):
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
         return state_dict
+
+
+@ray.remote
+class GRPOEvalConsumer(BaseConsumer):
+    def __init__(
+        self,
+        num_producers,
+        num_episodes,
+        rank,
+        world_size,
+        master_addr,
+        master_port,
+        num_update_per_episode,
+        num_recv_per_update,
+        batch_size,
+        model_config,
+        plugin_config,
+        microbatch_size=1,
+        num_generations=4,
+        use_wandb=True,
+        log_dir="./results",
+    ):
+        super().__init__(
+            num_producers,
+            num_episodes,
+            rank,
+            world_size,
+            master_addr,
+            master_port,
+            num_update_per_episode,
+            num_recv_per_update,
+            batch_size,
+            model_config,
+            plugin_config,
+            microbatch_size,
+        )
+        path = model_config.pop("path")
+        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
+        self.policy_model.train()
+        self.accum_reward = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.accum_count = torch.zeros(1, device=self.device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.num_generations = num_generations
+
+        # Initialize verifiable reward.
+        response_format_tags = {
+            "think_start": {"text": "<think>", "num_occur": 1},
+            "think_end": {"text": "</think>", "num_occur": 1},
+            "answer_start": {"text": "<answer>", "num_occur": 1},
+            "answer_end": {"text": "</answer>", "num_occur": 1},
+        }
+        self.reward_model = VerifiableReward(
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+        )
+
+        self.log_dir = log_dir
+        if not os.path.exists(self.log_dir):
+            os.makedirs(self.log_dir)
+        else:
+            os.system(f"rm -rf {self.log_dir}/*")
+
+    def setup(self):
+        super().setup()
+        self.policy_model, _, *_ = self.booster.boost(self.policy_model)
+
+    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+        rank = dist.get_rank()
+        data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
+        kwargs["input_ids"].size(0)
+        reward_group = self.reward_model(
+            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+        )
+        reward = [value[0].item() for value in reward_group]
+        format_reward = [value[1].item() for value in reward_group]
+        acc_reward = [value[2].item() for value in reward_group]
+        response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
+
+        response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
+        with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
+            for i in range(len(response)):
+                f.write(
+                    json.dumps(
+                        {
+                            "response": response[i],
+                            "reward": reward[i],
+                            "format_reward": format_reward[i],
+                            "acc_reward": acc_reward[i],
+                            "response_length": response_length[i],
+                        },
+                        ensure_ascii=False,
+                    )
+                    + "\n"
+                )
+
+        self.accum_reward += sum(reward)
+        self.accum_format_reward += sum(format_reward)
+        self.accum_acc_reward += sum(acc_reward)
+        self.accum_response_length += sum(response_length)
+        self.accum_count += len(reward)
+
+        # print results
+        total_count = all_reduce_mean(self.accum_count, self.plugin)
+        mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
+        mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
+        mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
+        mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
+        if rank == 0:
+            print(
+                f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
+            )
+        return None
+
+    def state_dict(self):
+        self.policy_model._force_wait_all_gather()
+        model = self.policy_model.unwrap()
+        state_dict = model.state_dict()
+        return state_dict

@@ -183,7 +183,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
             raise ImportError("vllm is not installed")
         model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
         path = model_config.pop("path")
-        self.llm = LLM(path, **model_config)
+        self.llm = LLM(model=path, **model_config)
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.generate_config = SamplingParams(**generate_config)
@@ -194,8 +194,15 @@ class VLLMInferenceBackend(BaseInferenceBackend):
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        micro_batch_input_ids = input_ids.tolist()
+        micro_batch_input_ids_no_padding = []
+        for i in range(micro_batch_size):
+            for j in range(input_ids.size(1)):
+                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
+                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
+                    break
         outputs = self.llm.generate(
-            prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
         out_tokens = []
         out_len = []
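
Note: the generate hunk above stops feeding left-padding into vLLM. prompt_token_ids takes raw token-id lists, so any pad tokens left in place would be treated as real prompt tokens; the loop keeps everything from the first non-pad token onward. A minimal sketch with hypothetical inputs (pad_token_id assumed to be 0 here):

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[0, 0, 11, 12], [0, 21, 22, 23]])  # left-padded prompts

    prompts_no_padding = []
    for row in input_ids.tolist():
        for j, tok in enumerate(row):
            if tok != pad_token_id:
                prompts_no_padding.append(row[j:])  # strip the leading pad tokens
                break
    print(prompts_no_padding)  # [[11, 12], [21, 22, 23]]
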

@@ -1,15 +1,13 @@
+import copy
 from typing import Any, Dict, Optional
 
 import ray
 
 from .consumer import SimpleConsumer
-from .grpo_consumer import GRPOConsumer
+from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {
-    "Simple": SimpleConsumer,
-    "GRPO": GRPOConsumer,
-}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
@@ -80,6 +78,12 @@ def launch_distributed(
             backend=inference_backend,
         )
         procs.append(producer)
+    generate_config_consumer = copy.deepcopy(generate_config)
+    generate_config_consumer.update(
+        dict(
+            backend=inference_backend,
+        )
+    )
     for i in range(num_consumer_procs):
         consumer = core_consumer.options(num_gpus=1).remote(
             num_producers=num_producers,
|
|||||||
model_config=train_model_config,
|
model_config=train_model_config,
|
||||||
plugin_config=plugin_config,
|
plugin_config=plugin_config,
|
||||||
microbatch_size=train_microbatch_size,
|
microbatch_size=train_microbatch_size,
|
||||||
|
generate_config=generate_config_consumer,
|
||||||
|
filter_range=[0.05, 9.0],
|
||||||
)
|
)
|
||||||
procs.append(consumer)
|
procs.append(consumer)
|
||||||
ray.get([p.setup.remote() for p in procs])
|
ray.get([p.setup.remote() for p in procs])
|
||||||
|

@@ -23,6 +23,7 @@ class PolicyLoss(nn.Module):
         advantages: torch.Tensor,
         per_token_kl: torch.Tensor,
         action_mask: Optional[torch.Tensor] = None,
+        loss_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         skip = False
         if action_mask is None:
@@ -38,5 +39,7 @@ class PolicyLoss(nn.Module):
             loss = masked_mean(loss, action_mask)
         else:
             loss = loss.mean(dim=1)
+        if loss_mask is not None:
+            loss = loss * loss_mask
         loss = loss.mean()
         return loss, skip, ratio.max()
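
Note: with the two added lines, PolicyLoss zeroes out the per-sample loss of filtered responses before the final reduction. A toy sketch of that step (illustrative values, not the commit's code):

    import torch

    per_sample_loss = torch.tensor([0.8, 1.2, 0.5, 0.9])  # loss.mean(dim=1), one value per response
    loss_mask = torch.tensor([True, False, True, True])   # from the reward-range filter

    loss = (per_sample_loss * loss_mask).mean()  # masked samples contribute zero but still count in the mean
    print(loss)  # tensor(0.5500)
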

@@ -15,18 +15,14 @@ if __name__ == "__main__":
     parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()
 
     ray.init(address="local", namespace="ray-example")
 
     inference_model_config = dict(path=args.model)
     train_model_config = dict(path=args.model)
-    generate_config = dict(
-        top_k=50,
-        top_p=0.9,
-        temperature=1.0,
-    )
+    generate_config = dict(top_k=50, top_p=0.9, temperature=0.7)
 
     if args.backend == "transformers":
         inference_model_config.update(
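
Note: the argparse change above is a fix as well as an extension: the old choices=["Simple, GRPO"] was a single string, so the only value that would ever validate was the literal "Simple, GRPO". A minimal sketch of the corrected behaviour:

    import argparse

    parser = argparse.ArgumentParser()
    # Separate strings let each algorithm name be selected on its own.
    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
    print(parser.parse_args(["-a", "EvalGRPO"]).algo)  # EvalGRPO
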
@@ -52,19 +48,13 @@ if __name__ == "__main__":
             )
         )
     elif args.backend == "vllm":
-        inference_model_config.update(
-            dict(
-                gpu_memory_utilization=0.7,
-            )
-        )
+        inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
         generate_config.update(
             dict(
                 max_tokens=2048,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
-                temperature=0.7,
-                top_p=0.95,
             )
         )
     else:
@@ -97,6 +87,6 @@ if __name__ == "__main__":
         plugin_config={},
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29504,
+        master_port=29503,
         core_algo=args.algo,
     )