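"""Rollout producers.

Each producer samples prompt microbatches from a conversation dataset, runs
rollouts through an inference backend, broadcasts the generated tensors to
consumer processes via Ray collective communication, and periodically receives
updated model weights back from the training (consumer) side.
"""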
from typing import Any, Dict, Optional

import ray
import ray.util.collective as cc
import torch
from coati.dataset.loader import RawConversationDataset
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send


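# Collective-communication layout (as implied by setup() and loop()):
#   * "sync_data_{producer_idx}": one group per producer, of size
#     1 + num_consumer_procs, with the producer at rank 0 broadcasting its
#     rollout tensors to the consumer processes.
#   * "sync_model": a single group of size num_producers + 1, in which each
#     producer occupies rank producer_idx and receives updated weights
#     broadcast from rank num_producers.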
class BaseProducer:
    def __init__(
        self,
        producer_idx: int,
        num_producers: int,
        num_consumer_procs: int,
        num_episodes: int,
        batch_size: int,
        dataset_config: Dict[str, Any],
        dataloaders_config: Dict[str, Any],
        model_config: Dict[str, Any],
        generate_config: Dict[str, Any],
        tokenizer_config: Optional[Dict[str, Any]] = None,
        microbatch_size: int = 1,
        backend: str = "transformers",
    ):
        self.producer_idx = producer_idx
        self.num_producers = num_producers
        self.num_consumer_procs = num_consumer_procs
        self.num_episodes = num_episodes
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        assert batch_size % microbatch_size == 0
        self.num_microbatches = batch_size // microbatch_size

        self.dataset_config = dataset_config
        self.model_config = model_config
        self.generate_config = generate_config
        self.tokenizer_config = tokenizer_config

        # init tokenizer
        if tokenizer_config is None:
            tokenizer_path = model_config["path"]
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        else:
            tokenizer_path = tokenizer_config.pop("path")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
        self.tokenizer.padding_side = "left"

        # init dataloader
        dataset_path = dataset_config.pop("path")
        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=microbatch_size,
            sampler=DistributedSampler(
                self.dataset,
                num_replicas=num_producers,
                rank=producer_idx,
                shuffle=True,
                drop_last=True,
                seed=42,
            ),
            num_workers=4,
        )
        self.device = get_current_device()

        # init backend
        if backend in BACKEND_MAP:
            self.backend_cls = BACKEND_MAP[backend]
        else:
            raise ValueError(f"Unexpected backend {backend}")

    def setup(self) -> None:
        # Join the per-producer data group (this producer is rank 0) and the
        # shared model-sync group (this producer is rank producer_idx).
        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
        cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")

    def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        raise NotImplementedError

    def loop(self) -> None:
        # Only iterate over microbatches that form complete consumer batches.
        num_update_per_episode = len(self.dataloader) // self.num_microbatches
        num_valid_microbatches = num_update_per_episode * self.num_microbatches

        print(
            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
        )
        for episode in range(self.num_episodes):
            self.dataloader.sampler.set_epoch(episode)
            for i, batch in enumerate(self.dataloader):
                if i >= num_valid_microbatches:
                    break
                outputs = self.rollout(**batch)
                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                outputs = pre_send(outputs)
                ray_broadcast_tensor_dict(
                    outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                )

                if (i + 1) % self.num_microbatches == 0 and (
                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                ):
                    # Don't sync the model for the last iteration; otherwise
                    # receive the updated weights broadcast from rank num_producers.
                    print(
                        f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                    )
                    state_dict = ray_broadcast_tensor_dict(
                        None, self.num_producers, device=self.device, group_name="sync_model"
                    )
                    self.load_state_dict(state_dict)


@ray.remote
class SimpleProducer(BaseProducer):
    def __init__(
        self,
        producer_idx,
        num_producers,
        num_consumer_procs,
        num_episodes,
        batch_size,
        dataset_config,
        dataloaders_config,
        model_config,
        generate_config,
        tokenizer_config=None,
        microbatch_size=1,
        backend="transformers",
    ):
        super().__init__(
            producer_idx,
            num_producers,
            num_consumer_procs,
            num_episodes,
            batch_size,
            dataset_config,
            dataloaders_config,
            model_config,
            generate_config,
            tokenizer_config,
            microbatch_size,
            backend,
        )
        self.model = self.backend_cls(model_config, generate_config, self.tokenizer)

    @torch.no_grad()
    def rollout(self, input_ids, attention_mask, **kwargs):
        if self.backend_cls.__name__ == "TransformersInferenceBackend":
            # Keep the ground-truth answer out of the generate() kwargs and
            # re-attach it to the output batch on the same device.
            gt_answer = kwargs.pop("gt_answer")
            out = self.model.generate(input_ids, attention_mask, **kwargs)
            out["gt_answer"] = gt_answer.to(out["input_ids"].device)
            return out
        return self.model.generate(input_ids, attention_mask, **kwargs)

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)
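

# ---------------------------------------------------------------------------
# Launch sketch (illustrative, not part of the original module). It assumes a
# consumer-side process joins the matching "sync_data_0" and "sync_model"
# collective groups; without one, setup()/loop() will block. All config values
# below are placeholders: only the "path" keys are required by the code above,
# while any remaining dataset/generate keys are forwarded to the dataset and
# inference backend respectively.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     ray.init()
#     producer = SimpleProducer.options(num_gpus=1).remote(
#         producer_idx=0,
#         num_producers=1,
#         num_consumer_procs=1,
#         num_episodes=1,
#         batch_size=4,
#         dataset_config={"path": "/path/to/conversations.jsonl"},
#         dataloaders_config={},
#         model_config={"path": "/path/to/model"},
#         generate_config={},
#         microbatch_size=2,
#         backend="transformers",
#     )
#     ray.get(producer.setup.remote())
#     ray.get(producer.loop.remote())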