mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[feature] fit RL style generation (#6213)
* [feature] fit rl style generation
* [doc] add docstr
* [doc] add docstr
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
@@ -25,16 +25,42 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
|
||||
|
||||
|
||||
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Compress the mask tensors of a batch before it is sent.

    The "attention_mask" and "action_mask" entries, when present, are cast to
    ``torch.bool`` to save bandwidth. All other entries are left untouched.

    Args:
        batch (Dict[str, torch.Tensor]): The batch to prepare. It is modified
            in place.

    Returns:
        Dict[str, torch.Tensor]: The same ``batch`` object, with its masks cast
            to ``torch.bool``.
    """
    # compress mask to save bandwidth
    if "attention_mask" in batch:
        batch["attention_mask"] = batch["attention_mask"].to(torch.bool)
    if "action_mask" in batch:
        batch["action_mask"] = batch["action_mask"].to(torch.bool)
    return batch
|
||||
|
||||
|
||||
def post_recv(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Decompress the mask tensors of a batch after it is received.

    The "attention_mask" and "action_mask" entries, when present, are cast to
    ``torch.int``. All other entries are left untouched.

    Note:
        The cast is always to ``torch.int``; whatever dtype the masks had
        before they were compressed is not restored.

    Args:
        batch (Dict[str, torch.Tensor]): The received batch. It is modified
            in place.

    Returns:
        Dict[str, torch.Tensor]: The same ``batch`` object, with its masks cast
            to ``torch.int``.
    """
    # decompress mask
    if "attention_mask" in batch:
        batch["attention_mask"] = batch["attention_mask"].to(torch.int)
    if "action_mask" in batch:
        batch["action_mask"] = batch["action_mask"].to(torch.int)
    return batch
|
||||
|
||||
|
||||
def update_by_default(data: Dict[str, Any], default: Dict[str, Any]) -> Dict[str, Any]:
    """
    Fill in missing keys of ``data`` from ``default``.

    Args:
        data (Dict[str, Any]): The values to keep. Not modified.
        default (Dict[str, Any]): Fallback values, used only for keys that are
            absent from ``data``.

    Returns:
        Dict[str, Any]: A shallow copy of ``data`` extended with every
            key/value pair of ``default`` whose key was not already present.
    """
    # Work on a shallow copy so the caller's dict is left as it was.
    merged = data.copy()
    for key, value in default.items():
        # Existing entries always win over the defaults.
        if key not in merged:
            merged[key] = value
    return merged
|
||||
|
||||
|
||||
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Compute the log probabilities from logits for the given labels.

    A log-softmax is taken over the last dimension of ``logits`` and, for every
    position, the entry selected by the corresponding label is picked out.

    Args:
        logits (torch.Tensor): The input logits. The last dimension is the one
            normalized by the log-softmax.
        labels (torch.Tensor): The target labels, used as indices into the last
            dimension of ``logits``; it has one dimension fewer than ``logits``.
            Must be a dtype accepted by ``torch.gather`` as an index.

    Returns:
        torch.Tensor: The log probabilities corresponding to the labels, with
            the same shape as ``labels``.
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    # gather needs an index of the same rank as log_probs, hence the extra
    # trailing dimension, which is squeezed away again on return.
    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return per_label_logps.squeeze(-1)
|
||||
|
Reference in New Issue
Block a user