from typing import Any, Dict, List, Optional

import ray
import ray.util.collective as cc
import torch
from coati.dataset.loader import RawConversationDataset
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send


class BaseProducer:
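    """Base class for rollout producers.

    A producer owns one shard of the dataset (via ``DistributedSampler``),
    generates rollouts with an inference backend, broadcasts them to the
    consumer (trainer) processes through Ray collective groups, and
    periodically pulls updated model weights back from the trainer.
    """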
    def __init__(
        self,
        producer_idx: int,
        num_producers: int,
        num_consumer_procs: int,
        num_episodes: int,
        batch_size: int,
        dataset_config: Dict[str, Any],
        dataloaders_config: Dict[str, Any],
        model_config: Dict[str, Any],
        generate_config: Dict[str, Any],
        consumer_plugin_config: Optional[Dict[str, Any]] = None,
        tokenizer_config: Optional[Dict[str, Any]] = None,
        microbatch_size: int = 1,
        backend: str = "transformers",
    ):
        self.producer_idx = producer_idx
        self.num_producers = num_producers
        self.num_consumer_procs = num_consumer_procs
        self.num_episodes = num_episodes
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        assert batch_size % microbatch_size == 0, "batch_size must be divisible by microbatch_size"
        self.num_microbatches = batch_size // microbatch_size

        self.dataset_config = dataset_config
        self.model_config = model_config
        self.generate_config = generate_config
        self.tokenizer_config = tokenizer_config
        self.consumer_plugin_config = consumer_plugin_config

        # init tokenizer
        if tokenizer_config is None:
            tokenizer_path = model_config["path"]
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        else:
            tokenizer_path = tokenizer_config.pop("path")
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
        # left padding so that batched decoder-only generation appends new
        # tokens directly after each prompt
        self.tokenizer.padding_side = "left"

        # init dataloader: each producer reads a disjoint shard of the dataset
        dataset_path = dataset_config.pop("path")
        self.dataset = RawConversationDataset(self.tokenizer, dataset_path, **dataset_config)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=microbatch_size,
            sampler=DistributedSampler(
                self.dataset,
                num_replicas=num_producers,
                rank=producer_idx,
                shuffle=True,
                drop_last=True,
                seed=42,
            ),
            num_workers=4,
            drop_last=True,
        )
        self.device = get_current_device()

        # init backend
        if backend in BACKEND_MAP:
            self.backend_cls = BACKEND_MAP[backend]
        else:
            raise ValueError(f"Unexpected backend {backend}")

    def setup(self, model_state_dict_keys: Optional[List] = None) -> None:
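        """Join the Ray collective groups used for data and weight sync.

        ``sync_data_{producer_idx}`` connects this producer (rank 0) with its
        consumer processes; each ``sync_model_pp_stage_{i}`` group connects
        all producers (this one at rank ``producer_idx``) with the trainer
        rank at ``self.num_producers``, which broadcasts updated weights in
        ``loop``.
        """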
        self.model_state_dict_keys = model_state_dict_keys
        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
        for pp_stage in range(self.consumer_plugin_config.get("pp_size", 1)):
            group_name = f"sync_model_pp_stage_{pp_stage}"
            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=group_name)

    def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        raise NotImplementedError

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        raise NotImplementedError

    def loop(self) -> None:
        num_update_per_episode = len(self.dataloader) // self.num_microbatches
        num_valid_microbatches = num_update_per_episode * self.num_microbatches

        print(
            f"[P{self.producer_idx}] num_valid_microbatches {num_valid_microbatches}, nmb: {self.num_microbatches}, dl: {len(self.dataloader)}"
        )
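        # Protocol per microbatch: roll out, broadcast the outputs to the
        # consumers over ``sync_data_{producer_idx}``, and after every
        # ``num_microbatches`` sends receive the updated policy weights from
        # the trainer (except after the final iteration).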
        for episode in range(self.num_episodes):
            self.dataloader.sampler.set_epoch(episode)
            for i, batch in enumerate(self.dataloader):
                if i >= num_valid_microbatches:
                    break
                outputs = self.rollout(**batch)

                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                outputs["temperature"] = torch.tensor(
                    [
                        (
                            self.model.generate_config["temperature"]
                            if isinstance(self.model.generate_config, dict)
                            else self.model.generate_config.temperature
                        )
                    ]
                    * outputs["input_ids"].size(0)
                ).to(outputs["input_ids"].device)
                outputs = pre_send(outputs)
                ray_broadcast_tensor_dict(
                    outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                )
                # sync the model every ``num_microbatches`` sends, but not
                # after the last iteration of the last episode
                if (i + 1) % self.num_microbatches == 0 and (
                    episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                ):
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                        "enable_sleep_mode", False
                    ):
                        self.model.llm.sleep()  # evict the KV cache to avoid OOM during the sync
                    torch.cuda.empty_cache()
                    state_dict = {}
                    print(
                        f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                    )
                    for pp_stage in range(self.consumer_plugin_config.get("pp_size", 1)):
                        group_name = f"sync_model_pp_stage_{pp_stage}"
                        state_dict.update(
                            ray_broadcast_tensor_dict(
                                None, src=self.num_producers, device=self.device, group_name=group_name
                            )
                        )
                    # check model sync integrity
                    assert len(state_dict) == len(self.model_state_dict_keys), (
                        f"Synced state dict has {len(state_dict)} unique keys, but the original model has "
                        f"{len(self.model_state_dict_keys)} keys. "
                        f"Missing keys: {set(self.model_state_dict_keys) - set(state_dict.keys())}. "
                        "Please report this to the developers."
                    )

                    self.load_state_dict(state_dict)
                    del state_dict
                    torch.cuda.empty_cache()
                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                        "enable_sleep_mode", False
                    ):
                        self.model.llm.wake_up()
                # linearly anneal the temperature from its initial value to 0.9
                # over the first episode
                if episode <= 0:
                    ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
                    if isinstance(self.model.generate_config, dict):
                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
                            "temperature"
                        ] + ratio * 0.9
                    else:
                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
                            "temperature"
                        ] + ratio * 0.9


@ray.remote
class SimpleProducer(BaseProducer):
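    """Producer that wraps a single inference backend instance for generation."""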
    def __init__(
        self,
        producer_idx,
        num_producers,
        num_consumer_procs,
        num_episodes,
        batch_size,
        dataset_config,
        dataloaders_config,
        model_config,
        generate_config,
        consumer_plugin_config=None,
        tokenizer_config=None,
        microbatch_size=1,
        backend="transformers",
        num_generations: int = 8,
    ):
        super().__init__(
            producer_idx,
            num_producers,
            num_consumer_procs,
            num_episodes,
            batch_size,
            dataset_config,
            dataloaders_config,
            model_config,
            generate_config,
            consumer_plugin_config,
            tokenizer_config,
            microbatch_size,
            backend,
        )
        self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)

    @torch.no_grad()
    def rollout(self, input_ids, attention_mask, **kwargs):
        rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
        # print one decoded sample from a single producer so the log is not
        # flooded with duplicates
        if self.producer_idx == 1:
            print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))

        return rollouts

    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict)
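

# Illustrative usage sketch (assumptions, not part of this module): the config
# values below are hypothetical placeholders; the real ones come from the
# launcher that also spawns the matching consumer (trainer) processes.
#
#   producer = SimpleProducer.options(num_gpus=1).remote(
#       producer_idx=0,
#       num_producers=1,
#       num_consumer_procs=1,
#       num_episodes=1,
#       batch_size=4,
#       dataset_config={"path": "/path/to/dataset.jsonl"},
#       dataloaders_config={},
#       model_config={"path": "/path/to/model"},
#       generate_config={"temperature": 1.0},
#       microbatch_size=2,
#       backend="transformers",
#   )
#   producer.setup.remote(model_state_dict_keys=list(state_dict.keys()))
#   producer.loop.remote()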