from typing import Any, Dict, List, Optional

import ray
import ray.util.collective as cc
import torch
from coati.dataset.loader import RawConversationDataset
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send


class BaseProducer:
    """Base class for rollout producers.

    A producer iterates over its shard of the dataset, generates rollouts with
    an inference backend, broadcasts them on the ``sync_data_<producer_idx>``
    collective group and periodically receives refreshed model weights on the
    ``sync_model_pp_stage_<n>`` groups.

    Subclasses must implement :meth:`rollout` and :meth:`load_state_dict`, and
    must set ``self.model`` (an instance of ``self.backend_cls``) before
    :meth:`loop` is called; this base class only selects the backend class.
    """

    def __init__(
        self,
        producer_idx: int,
        num_producers: int,
        num_consumer_procs: int,
        num_episodes: int,
        batch_size: int,
        dataset_config: Dict[str, Any],
        dataloaders_config: Dict[str, Any],
        model_config: Dict[str, Any],
        generate_config: Dict[str, Any],
        consumer_plugin_config: Optional[Dict[str, Any]] = None,
        tokenizer_config: Optional[Dict[str, Any]] = None,
        microbatch_size: int = 1,
        backend: str = "transformers",
    ):
        """Build the tokenizer, dataset and dataloader and pick the backend class.

        Args:
            producer_idx: Index of this producer; used as the sampler rank and
                in collective group names.
            num_producers: Total number of producers (sampler replica count).
            num_consumer_procs: Number of consumer processes joining this
                producer's data group.
            num_episodes: Number of passes over the dataloader.
            batch_size: Number of samples per model update; must be a multiple
                of ``microbatch_size``.
            dataset_config: Must contain ``"path"``; the remaining entries are
                forwarded to ``RawConversationDataset``.
            dataloaders_config: Accepted for interface compatibility; not used
                by this class.
            model_config: Backend model configuration. Its ``"path"`` entry is
                used as the tokenizer path when ``tokenizer_config`` is None.
            generate_config: Generation configuration; its ``"temperature"``
                entry is the starting point of the temperature annealing.
            consumer_plugin_config: Optional consumer plugin settings; only
                ``"pp_size"`` is read here (default 1).
            tokenizer_config: Optional; must contain ``"path"``, the remaining
                entries are forwarded to ``AutoTokenizer.from_pretrained``.
            microbatch_size: Dataloader batch size.
            backend: Key into ``BACKEND_MAP``.

        Raises:
            ValueError: If ``backend`` is not a key of ``BACKEND_MAP``.
        """
        self.producer_idx = producer_idx
        self.num_producers = num_producers
        self.num_consumer_procs = num_consumer_procs
        self.num_episodes = num_episodes
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        assert batch_size % microbatch_size == 0
        self.num_microbatches = batch_size // microbatch_size

        self.dataset_config = dataset_config
        self.model_config = model_config
        self.generate_config = generate_config
        self.tokenizer_config = tokenizer_config
        self.consumer_plugin_config = consumer_plugin_config

        # init tokenizer
        if tokenizer_config is None:
            tokenizer_path = model_config["path"]
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        else:
            # Pop from a shallow copy so the caller's dict keeps its "path" entry
            # and can be reused to build another producer.
            tokenizer_kwargs = dict(tokenizer_config)
            tokenizer_path = tokenizer_kwargs.pop("path")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_kwargs)
        self.tokenizer.padding_side = "left"

        # init dataloader (same copy-before-pop reasoning as for the tokenizer)
        dataset_kwargs = dict(dataset_config)
        dataset_path = dataset_kwargs.pop("path")
        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_kwargs)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=microbatch_size,
            sampler=DistributedSampler(
                self.dataset,
                num_replicas=num_producers,
                rank=producer_idx,
                shuffle=True,
                drop_last=True,
                seed=42,
            ),
            num_workers=4,
            drop_last=True,
        )
        self.device = get_current_device()

        # init backend
        if backend in BACKEND_MAP:
            self.backend_cls = BACKEND_MAP[backend]
        else:
            raise ValueError(f"Unexpected backend {backend}")

    def _pp_size(self) -> int:
        """Return the consumer's ``pp_size`` (1 when no plugin config was given)."""
        plugin_config = self.consumer_plugin_config or {}
        return plugin_config.get("pp_size", 1)

    def _sleep_mode_enabled(self) -> bool:
        """Tell whether the model is a vLLM backend configured with ``enable_sleep_mode``."""
        return bool(
            isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get("enable_sleep_mode", False)
        )

    def _get_temperature(self) -> Any:
        """Read the backend's current sampling temperature.

        The backend's ``generate_config`` may be a plain dict or an object
        exposing a ``temperature`` attribute; both are supported.
        """
        config = self.model.generate_config
        if isinstance(config, dict):
            return config["temperature"]
        return config.temperature

    def _set_temperature(self, value: float) -> None:
        """Write ``value`` as the backend's sampling temperature (dict or attribute style)."""
        config = self.model.generate_config
        if isinstance(config, dict):
            config["temperature"] = value
        else:
            config.temperature = value

    def setup(self, model_state_dict_keys: Optional[List] = None) -> None:
        """Join the collective groups used for data and model synchronisation.

        Args:
            model_state_dict_keys: Keys expected in every received state dict.
                When given, :meth:`loop` checks each synced state dict against
                them; when None, that check is skipped.
        """
        self.model_state_dict_keys = model_state_dict_keys
        # This producer is rank 0 of its own data group.
        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
        for pp_stage in range(self._pp_size()):
            group_name = f"sync_model_pp_stage_{pp_stage}"
            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name)

    def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Generate rollouts for one microbatch. Must be implemented by subclasses."""
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Load freshly synced weights into the model. Must be implemented by subclasses."""
        raise NotImplementedError

    def _sync_model(self, episode: int, step: int) -> None:
        """Receive a state dict over the model-sync groups and load it.

        Args:
            episode: Current episode index (used for logging only).
            step: Update step within the episode (used for logging only).

        Raises:
            AssertionError: If expected keys were registered in :meth:`setup`
                and the number of received keys differs.
        """
        sleep_mode = self._sleep_mode_enabled()
        if sleep_mode:
            self.model.llm.sleep()  # evict KV_cache to avoid OOM
        torch.cuda.empty_cache()
        state_dict = {}
        print(f"[P{self.producer_idx}] Sync model episode {episode} step {step}")
        for pp_stage in range(self._pp_size()):
            group_name = f"sync_model_pp_stage_{pp_stage}"
            state_dict.update(
                ray_broadcast_tensor_dict(None, src=self.num_producers, device=self.device, group_name=group_name)
            )
        # check model sync integrity (only possible when the expected keys were registered)
        if self.model_state_dict_keys is not None:
            assert len(state_dict) == len(
                self.model_state_dict_keys
            ), f"state dict keys has {len(state_dict)} unique keys not equal original model with {len(self.model_state_dict_keys)} keys. Missing keys: {set(self.model_state_dict_keys)-set(state_dict.keys())}. Please kindly inform the developer."
        self.load_state_dict(state_dict)
        del state_dict
        torch.cuda.empty_cache()
        if sleep_mode:
            self.model.llm.wake_up()

    def loop(self) -> None:
        """Run the produce / broadcast / sync cycle for ``num_episodes`` episodes.

        For every microbatch: generate rollouts, attach the current sampling
        temperature, and broadcast the result on this producer's data group.
        After every ``num_microbatches`` microbatches (except the very last one
        of the run) new weights are received and loaded, and during the first
        episode the temperature is then moved linearly from its initial value
        towards 0.9.
        """
        num_update_per_episode = len(self.dataloader) // self.num_microbatches
        # Trailing microbatches that cannot fill a whole update are skipped.
        num_valid_microbatches = num_update_per_episode * self.num_microbatches

        print(
            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
        )
        for episode in range(self.num_episodes):
            self.dataloader.sampler.set_epoch(episode)
            for i, batch in enumerate(self.dataloader):
                if i >= num_valid_microbatches:
                    break
                outputs = self.rollout(**batch)

                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                # One temperature entry per row of the rollout batch.
                outputs["temperature"] = torch.tensor([self._get_temperature()] * outputs["input_ids"].size(0)).to(
                    outputs["input_ids"].device
                )
                outputs = pre_send(outputs)
                ray_broadcast_tensor_dict(
                    outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                )

                completes_update = (i + 1) % self.num_microbatches == 0
                is_last_iteration = episode == self.num_episodes - 1 and i == num_valid_microbatches - 1
                # don't sync model for last iteration
                if not completes_update or is_last_iteration:
                    continue

                self._sync_model(episode, (i + 1) // self.num_microbatches - 1)

                # linear annealing for 1 episode, temperature from initial to 0.9
                if episode <= 0:
                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
                    self._set_temperature((1 - ratio) * self.generate_config["temperature"] + ratio * 0.9)


@ray.remote
class SimpleProducer(BaseProducer):
    """Ray actor producer that instantiates the selected backend and delegates to it."""

    def __init__(
        self,
        producer_idx,
        num_producers,
        num_consumer_procs,
        num_episodes,
        batch_size,
        dataset_config,
        dataloaders_config,
        model_config,
        generate_config,
        consumer_plugin_config=None,
        tokenizer_config=None,
        microbatch_size=1,
        backend="transformers",
        num_generations: int = 8,
    ):
        """Initialise the base producer, then build the backend model.

        ``num_generations`` is passed to the backend constructor together with
        the model config, generation config and tokenizer; all other arguments
        are forwarded unchanged to :class:`BaseProducer`.
        """
        super().__init__(
            producer_idx,
            num_producers,
            num_consumer_procs,
            num_episodes,
            batch_size,
            dataset_config,
            dataloaders_config,
            model_config,
            generate_config,
            consumer_plugin_config,
            tokenizer_config,
            microbatch_size,
            backend,
        )
        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)

    @torch.no_grad()
    def rollout(self, input_ids, attention_mask, **kwargs):
        """Generate rollouts with the backend model and return its output unchanged."""
        rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
        # Only the producer with index 1 prints a decoded sample.
        if self.producer_idx == 1:
            print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))

        return rollouts

    def load_state_dict(self, state_dict):
        """Forward the received state dict to the backend model."""
        self.model.load_state_dict(state_dict)