[feature] fit RL style generation (#6213)

* [feature] fit rl style generation

* [doc] add docstr

* [doc] add docstr
This commit is contained in:
Hongxin Liu
2025-02-21 17:28:19 +08:00
committed by GitHub
parent 43c9b5fb44
commit de282dd694
3 changed files with 140 additions and 49 deletions

View File

@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Any, Dict, List
import torch
@@ -25,16 +25,42 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# compress attention_mask to save bandwidth
# compress mask to save bandwidth
if "attention_mask" in batch:
attention_mask = batch["attention_mask"]
batch["attention_mask"] = attention_mask.to(torch.bool)
batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
if "action_mask" in batch:
batch["action_mask"] = batch["action_mask"].to(torch.bool)
return batch
def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# decompress attention_mask
# decompress mask
if "attention_mask" in batch:
attention_mask = batch["attention_mask"]
batch["attention_mask"] = attention_mask.to(torch.int)
batch["attention_mask"] = batch["attention_mask"].to(torch.int)
if "action_mask" in batch:
batch["action_mask"] = batch["action_mask"].to(torch.int)
return batch
def update_by_default(data: Dict[str, Any], default: Dict[str, Any]) -> Dict[str, Any]:
data = data.copy()
for k, v in default.items():
if k not in data:
data[k] = v
return data
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
Compute the log probabilities from logits for the given labels.
Args:
logits (torch.Tensor): The input logits.
labels (torch.Tensor): The target labels.
Returns:
torch.Tensor: The log probabilities corresponding to the labels.
"""
log_probs = torch.log_softmax(logits, dim=-1)
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return per_label_logps.squeeze(-1)