mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 01:48:07 +00:00
[feat] GRPO with distributed implementation (#6230)
* add reward related function * add simple grpo * update grpo * polish * modify data loader * grpo consumer * update loss * update reward fn * update example * update loader * add algo selection * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add save * update select algo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update grpo * update reward fn * update reward * fix reward score * add response length * detach * fix tp bug * fix consumer * convert to 8 generation * print results * setup update * fix transformers backend * [Feature] Support Distributed LogProb for GRPO Training (#6247) * [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * fix vllm * fix logprob, add filtering, temperature annealing, lr descent * simplify vllm preprocessing input ids * update logging * [feat] add microbatch forwarding (#6251) * add microbatch forwarding * fix forward microbatch * fix producer OOM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change project name * fix temperature annealing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address conversation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Distributed RLHF] Integration of PP (#6257) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> * [hot-fix] Fix memory leakage bug, support TP+PP (#6258) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
This commit is contained in:
@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
)
|
||||
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
|
||||
|
||||
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
):
|
||||
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
|
||||
model_config.update(self.FORCE_MODEL_CONFIG)
|
||||
path = model_config.pop("path")
|
||||
@@ -61,12 +67,22 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
self.generate_config = generate_config.copy()
|
||||
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = num_generations
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
micro_batch_size = input_ids.size(0)
|
||||
input_ids = input_ids.to(get_current_device())
|
||||
attention_mask = attention_mask.to(get_current_device())
|
||||
out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
|
||||
gt_answer = None
|
||||
if "gt_answer" in kwargs:
|
||||
gt_answer = kwargs.pop("gt_answer")
|
||||
if self.num_generations > 1:
|
||||
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
|
||||
out = self.model.generate(
|
||||
input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
|
||||
)
|
||||
input_len = input_ids.shape[-1]
|
||||
new_token_ids = out.sequences[:, input_len:]
|
||||
# get log probs
|
||||
@@ -76,10 +92,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
|
||||
action_log_probs = torch.cat(action_log_probs, dim=1)
|
||||
# get action mask
|
||||
response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
|
||||
action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
|
||||
if self.tokenizer.eos_token_id is not None:
|
||||
for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
|
||||
action_mask[indices[0], indices[1] + 1 :] = 0
|
||||
response_idx[:, 0] = input_len
|
||||
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
|
||||
|
||||
if attention_mask.size(0) != action_mask.size(0):
|
||||
assert action_mask.size(0) % attention_mask.size(0) == 0
|
||||
@@ -91,7 +110,15 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
"attention_mask": attention_mask,
|
||||
"action_log_probs": action_log_probs,
|
||||
"action_mask": action_mask,
|
||||
"response_idx": response_idx,
|
||||
}
|
||||
|
||||
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
|
||||
|
||||
if gt_answer is not None:
|
||||
# repeat gt_answer for each prompt.
|
||||
data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
|
||||
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
||||
return data
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
|
||||
@@ -99,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
|
||||
|
||||
|
||||
class SGLangInferenceBackend(BaseInferenceBackend):
|
||||
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
):
|
||||
if sgl is None:
|
||||
raise ImportError("sglang is not installed")
|
||||
path = model_config.pop("path")
|
||||
@@ -156,29 +189,46 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
logprobs=0,
|
||||
)
|
||||
|
||||
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
num_generations: int = 8,
|
||||
):
|
||||
if LLM is None:
|
||||
raise ImportError("vllm is not installed")
|
||||
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
|
||||
path = model_config.pop("path")
|
||||
self.llm = LLM(path, **model_config)
|
||||
self.llm = LLM(model=path, **model_config)
|
||||
generate_config = generate_config.copy()
|
||||
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
generate_config.update({"n": num_generations})
|
||||
self.generate_config = SamplingParams(**generate_config)
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = num_generations
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
micro_batch_size = input_ids.size(0)
|
||||
response_start_idx = input_ids.size(1)
|
||||
first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
|
||||
micro_batch_input_ids = input_ids.tolist()
|
||||
micro_batch_input_ids_no_padding = [
|
||||
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
|
||||
]
|
||||
outputs = self.llm.generate(
|
||||
prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
|
||||
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
|
||||
)
|
||||
out_tokens = []
|
||||
out_len = []
|
||||
log_probs = []
|
||||
response_idx = []
|
||||
for out in outputs:
|
||||
for output_i in out.outputs:
|
||||
out_len.append(len(output_i.token_ids))
|
||||
out_tokens.append(list(output_i.token_ids))
|
||||
response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
|
||||
assert len(output_i.logprobs) == len(output_i.token_ids)
|
||||
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
|
||||
log_probs.append(p)
|
||||
@@ -195,6 +245,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
|
||||
out_tokens = torch.tensor(out_tokens)
|
||||
log_probs = torch.tensor(log_probs)
|
||||
response_idx = torch.tensor(response_idx)
|
||||
|
||||
if attention_mask.size(0) != action_mask.size(0):
|
||||
assert action_mask.size(0) % attention_mask.size(0) == 0
|
||||
num_returns = action_mask.size(0) // attention_mask.size(0)
|
||||
@@ -209,7 +261,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
"attention_mask": attention_mask,
|
||||
"action_log_probs": log_probs,
|
||||
"action_mask": action_mask,
|
||||
"response_idx": response_idx,
|
||||
}
|
||||
|
||||
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
|
||||
|
||||
if "gt_answer" in kwargs:
|
||||
# repeat gt_answer for each prompt.
|
||||
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
|
||||
data = {k: v.to(get_current_device()) for k, v in data.items()}
|
||||
return data
|
||||
|
||||
|
Reference in New Issue
Block a user