Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-31)
	[Feat]Inference RPC Server Support (#5705)
* rpc support source
* kv cache logical/physical disaggregation
* sampler refactor
* colossalai launch built in
* Unit tests
* rpyc support

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
		| @@ -2,11 +2,11 @@ | |||||||
| Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference. | Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference. | ||||||
| """ | """ | ||||||
| import logging | import logging | ||||||
|  | from abc import ABC, abstractmethod | ||||||
| from dataclasses import dataclass, fields | from dataclasses import dataclass, fields | ||||||
| from typing import Any, Dict, Optional, Union | from typing import Any, Dict, List, Optional, Union | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
| import torch.distributed as dist |  | ||||||
| from transformers.generation import GenerationConfig | from transformers.generation import GenerationConfig | ||||||
|  |  | ||||||
| from colossalai.inference.flash_decoding_utils import FDIntermTensors | from colossalai.inference.flash_decoding_utils import FDIntermTensors | ||||||
| @@ -30,8 +30,25 @@ _DEFAULT_PROMPT_TEMPLATES = { | |||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RPC_PARAM(ABC): | ||||||
|  |     """ | ||||||
|  |     NOTE(lry89757): We use rpyc to transport parameters between the client and the server. | ||||||
|  |     rpyc can only pass plain Python (POD) values as parameters, so tensors and other complex objects must be converted before they cross the RPC boundary. | ||||||
|  |     Following the logic of `__setstate__`/`__getstate__`, classes that are later sent as RPC parameters inherit from this base class and override `to_rpc_param` and `from_rpc_param`: the client calls `to_rpc_param` to serialize the parameters, and the server reconstructs them with `from_rpc_param`. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     @abstractmethod | ||||||
|  |     def to_rpc_param(self): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     @abstractmethod | ||||||
|  |     def from_rpc_param(): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class InputMetaData: | class InputMetaData(RPC_PARAM): | ||||||
|     """The input info for a single step |     """The input info for a single step | ||||||
|  |  | ||||||
|     Args: |     Args: | ||||||
| @@ -48,6 +65,7 @@ class InputMetaData: | |||||||
|     dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32. |     dtype (torch.dtype, optional): The computation type of tensor, Defaults to torch.float32. | ||||||
|     use_spec_dec (bool): Indicate whether to use speculative decoding. |     use_spec_dec (bool): Indicate whether to use speculative decoding. | ||||||
|     num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True. |     num_tokens_to_verify (int): The number of tokens to verify in speculative decoding. Only valid when `use_spec_dec` is set to True. | ||||||
|  |     batch_token_ids (List[List[int]], optional): input_token_ids + output_token_ids of current batch. Only used for `repetition_penalty`, `no_repeat_ngram_size` in sampler process. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     block_tables: torch.Tensor = None |     block_tables: torch.Tensor = None | ||||||
| @@ -63,6 +81,54 @@ class InputMetaData: | |||||||
|     dtype: torch.dtype = torch.float32 |     dtype: torch.dtype = torch.float32 | ||||||
|     use_spec_dec: bool = False |     use_spec_dec: bool = False | ||||||
|     num_tokens_to_verify: int = 0 |     num_tokens_to_verify: int = 0 | ||||||
|  |     batch_token_ids: Optional[ | ||||||
|  |         List[List[int]] | ||||||
|  |     ] = None  # for `repetition_penalty`, `no_repeat_ngram_size` in sampler process | ||||||
|  |  | ||||||
|  |     def to_rpc_param(self) -> Dict[str, Any]: | ||||||
|  |         return { | ||||||
|  |             "block_tables": self.block_tables.tolist(), | ||||||
|  |             "sequence_lengths": self.sequence_lengths.tolist(), | ||||||
|  |             "batch_size": self.batch_size, | ||||||
|  |             "is_prompts": self.is_prompts, | ||||||
|  |             "use_cuda_kernel": self.use_cuda_kernel, | ||||||
|  |             "use_cuda_graph": self.use_cuda_graph, | ||||||
|  |             "kv_seq_len": self.kv_seq_len, | ||||||
|  |             "head_dim": self.head_dim, | ||||||
|  |             "high_precision": self.high_precision, | ||||||
|  |             "dtype": str(self.dtype).split(".")[-1], | ||||||
|  |             "use_spec_dec": self.use_spec_dec, | ||||||
|  |             "num_tokens_to_verify": self.num_tokens_to_verify, | ||||||
|  |             "batch_token_ids": self.batch_token_ids, | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def from_rpc_param(rpc_dict: Dict[str, Any]) -> "InputMetaData": | ||||||
|  |         """ | ||||||
|  |         We intentionally avoid `dict.get` so that a missing RPC parameter raises a KeyError instead of silently falling back to a default. | ||||||
|  |         """ | ||||||
|  |         from colossalai.accelerator import get_accelerator | ||||||
|  |  | ||||||
|  |         dtype = getattr(torch, rpc_dict["dtype"]) | ||||||
|  |         return InputMetaData( | ||||||
|  |             block_tables=torch.tensor( | ||||||
|  |                 rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device() | ||||||
|  |             ), | ||||||
|  |             sequence_lengths=torch.tensor( | ||||||
|  |                 rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() | ||||||
|  |             ), | ||||||
|  |             batch_size=rpc_dict["batch_size"], | ||||||
|  |             is_prompts=rpc_dict["is_prompts"], | ||||||
|  |             use_cuda_kernel=rpc_dict["use_cuda_kernel"], | ||||||
|  |             use_cuda_graph=rpc_dict["use_cuda_graph"], | ||||||
|  |             kv_seq_len=rpc_dict["kv_seq_len"], | ||||||
|  |             head_dim=rpc_dict["head_dim"], | ||||||
|  |             high_precision=rpc_dict["high_precision"], | ||||||
|  |             dtype=dtype, | ||||||
|  |             use_spec_dec=rpc_dict["use_spec_dec"], | ||||||
|  |             num_tokens_to_verify=rpc_dict["num_tokens_to_verify"], | ||||||
|  |             batch_token_ids=rpc_dict["batch_token_ids"], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def __repr__(self) -> str: |     def __repr__(self) -> str: | ||||||
|         return ( |         return ( | ||||||
| @@ -80,7 +146,7 @@ class InputMetaData: | |||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| class InferenceConfig: | class InferenceConfig(RPC_PARAM): | ||||||
|     """The inference configuration. |     """The inference configuration. | ||||||
|  |  | ||||||
|     Args: |     Args: | ||||||
| @@ -193,10 +259,6 @@ class InferenceConfig: | |||||||
|         if self.dtype == torch.float32: |         if self.dtype == torch.float32: | ||||||
|             self.high_precision = False |             self.high_precision = False | ||||||
|  |  | ||||||
|         # check distributed |  | ||||||
|         assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or ( |  | ||||||
|             self.tp_size * self.pp_size == dist.get_world_size() |  | ||||||
|         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})" |  | ||||||
|         # check prompt template |         # check prompt template | ||||||
|         if self.prompt_template is None: |         if self.prompt_template is None: | ||||||
|             return |             return | ||||||
| @@ -226,6 +288,43 @@ class InferenceConfig: | |||||||
|  |  | ||||||
|         return GenerationConfig.from_dict(meta_config) |         return GenerationConfig.from_dict(meta_config) | ||||||
|  |  | ||||||
|  |     def to_rpc_param(self) -> dict: | ||||||
|  |         kwargs = { | ||||||
|  |             "dtype": str(self.dtype).split(".")[-1], | ||||||
|  |             "max_n_spec_tokens": self.max_n_spec_tokens, | ||||||
|  |             "max_batch_size": self.max_batch_size, | ||||||
|  |             "max_input_len": self.max_input_len, | ||||||
|  |             "max_output_len": self.max_output_len, | ||||||
|  |             "tp_size": self.tp_size, | ||||||
|  |             "pp_size": self.pp_size, | ||||||
|  |             "pad_input": self.pad_input, | ||||||
|  |             "early_stopping": self.early_stopping, | ||||||
|  |             "do_sample": self.do_sample, | ||||||
|  |             "beam_width": self.beam_width, | ||||||
|  |             "kv_cache_dtype": str(self.kv_cache_dtype).split(".")[-1], | ||||||
|  |         } | ||||||
|  |         return kwargs | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def from_rpc_param(rpc_dict: dict) -> "InferenceConfig": | ||||||
|  |         """ | ||||||
|  |         We intentionally avoid `dict.get` so that a missing RPC parameter raises a KeyError instead of silently falling back to a default. | ||||||
|  |         """ | ||||||
|  |         return InferenceConfig( | ||||||
|  |             dtype=getattr(torch, rpc_dict["dtype"]), | ||||||
|  |             max_n_spec_tokens=rpc_dict["max_n_spec_tokens"], | ||||||
|  |             max_batch_size=rpc_dict["max_batch_size"], | ||||||
|  |             max_input_len=rpc_dict["max_input_len"], | ||||||
|  |             max_output_len=rpc_dict["max_output_len"], | ||||||
|  |             tp_size=rpc_dict["tp_size"], | ||||||
|  |             pp_size=rpc_dict["pp_size"], | ||||||
|  |             pad_input=rpc_dict["pad_input"], | ||||||
|  |             early_stopping=rpc_dict["early_stopping"], | ||||||
|  |             do_sample=rpc_dict["do_sample"], | ||||||
|  |             beam_width=rpc_dict["beam_width"], | ||||||
|  |             kv_cache_dtype=getattr(torch, rpc_dict["kv_cache_dtype"], None), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig": |     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig": | ||||||
|         # Get the list of attributes of this dataclass. |         # Get the list of attributes of this dataclass. | ||||||
|   | |||||||
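To make the `RPC_PARAM` contract above concrete, here is a minimal sketch (not part of the commit) of the expected round trip: tensors are flattened to plain Python lists on the client, cross the rpyc boundary as a POD dict, and are rebuilt on the receiving side. `SamplingState` is a hypothetical class used only for illustration; `InputMetaData` and `InferenceConfig` in this diff follow the same pattern, assuming the module path shown above.

```python
from dataclasses import dataclass
from typing import Dict

import torch

from colossalai.inference.config import RPC_PARAM  # module path as shown in this diff


@dataclass
class SamplingState(RPC_PARAM):  # hypothetical example class
    token_ids: torch.Tensor
    temperature: float

    def to_rpc_param(self) -> Dict[str, object]:
        # Only plain Python (POD) values may cross the rpyc boundary.
        return {"token_ids": self.token_ids.tolist(), "temperature": self.temperature}

    @staticmethod
    def from_rpc_param(rpc_dict: Dict[str, object]) -> "SamplingState":
        # Rebuild the tensor on the receiving side (the RPC worker).
        return SamplingState(
            token_ids=torch.tensor(rpc_dict["token_ids"], dtype=torch.int),
            temperature=rpc_dict["temperature"],
        )


# Round trip: what the client serializes is what the server reconstructs.
state = SamplingState(token_ids=torch.tensor([1, 2, 3], dtype=torch.int), temperature=0.7)
restored = SamplingState.from_rpc_param(state.to_rpc_param())
assert torch.equal(state.token_ids, restored.token_ids)
```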
| @@ -21,6 +21,7 @@ from colossalai.inference.batch_bucket import BatchBucket | |||||||
| from colossalai.inference.config import InferenceConfig, InputMetaData | from colossalai.inference.config import InferenceConfig, InputMetaData | ||||||
| from colossalai.inference.graph_runner import CUDAGraphRunner | from colossalai.inference.graph_runner import CUDAGraphRunner | ||||||
| from colossalai.inference.modeling.policy import model_policy_map | from colossalai.inference.modeling.policy import model_policy_map | ||||||
|  | from colossalai.inference.sampler import search_tokens | ||||||
| from colossalai.inference.spec import Drafter, GlideInput | from colossalai.inference.spec import Drafter, GlideInput | ||||||
| from colossalai.inference.struct import Sequence | from colossalai.inference.struct import Sequence | ||||||
| from colossalai.inference.utils import get_model_size, has_index_file | from colossalai.inference.utils import get_model_size, has_index_file | ||||||
| @@ -424,7 +425,7 @@ class InferenceEngine: | |||||||
|  |  | ||||||
|         # 2. Prefill main model (Verifier) - fill past kv cache for main model |         # 2. Prefill main model (Verifier) - fill past kv cache for main model | ||||||
|         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) |         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) | ||||||
|         next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) |         next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) | ||||||
|         # append new inputs to the batch, temporarily |         # append new inputs to the batch, temporarily | ||||||
|         batch.append_batch_tokens(next_tokens) |         batch.append_batch_tokens(next_tokens) | ||||||
|         self.request_handler.allocate_batch_spec_dec(batch, 1) |         self.request_handler.allocate_batch_spec_dec(batch, 1) | ||||||
| @@ -472,7 +473,7 @@ class InferenceEngine: | |||||||
|             input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) |             input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) | ||||||
|             logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) |             logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) | ||||||
|  |  | ||||||
|             next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) |             next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) | ||||||
|  |  | ||||||
|             # 5. Compare and process the results |             # 5. Compare and process the results | ||||||
|             diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) |             diff_indexes = torch.nonzero(~(next_tokens[:-1] == next_token_ids_spec)) | ||||||
| @@ -689,6 +690,13 @@ class InferenceEngine: | |||||||
|             (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device |             (n_tokens, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |         batch_token_ids = None | ||||||
|  |         config_dict = self.generation_config.to_dict() | ||||||
|  |         # process repetition_penalty, no_repeat_ngram_size | ||||||
|  |         for type in ["repetition_penalty", "no_repeat_ngram_size"]: | ||||||
|  |             if type in config_dict and config_dict[type] is not None: | ||||||
|  |                 batch_token_ids = batch.batch_token_ids | ||||||
|  |  | ||||||
|         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference |         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference | ||||||
|         use_cuda_graph = False |         use_cuda_graph = False | ||||||
|         if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): |         if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): | ||||||
| @@ -708,6 +716,7 @@ class InferenceEngine: | |||||||
|             dtype=batch.dtype, |             dtype=batch.dtype, | ||||||
|             use_spec_dec=batch.use_spec_dec, |             use_spec_dec=batch.use_spec_dec, | ||||||
|             num_tokens_to_verify=batch.num_tokens_to_verify, |             num_tokens_to_verify=batch.num_tokens_to_verify, | ||||||
|  |             batch_token_ids=batch_token_ids, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         return input_ids, output_tensor, input_meta_data |         return input_ids, output_tensor, input_meta_data | ||||||
| @@ -738,7 +747,9 @@ class InferenceEngine: | |||||||
|         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) |         logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) | ||||||
|         if self.inference_config.pad_input: |         if self.inference_config.pad_input: | ||||||
|             logits = logits[:, -1, :] |             logits = logits[:, -1, :] | ||||||
|         next_tokens = self.request_handler.search_tokens(self.generation_config, logits, batch) |         next_tokens = search_tokens( | ||||||
|  |             self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids | ||||||
|  |         ) | ||||||
|         self.request_handler.append_next_tokens(next_tokens) |         self.request_handler.append_next_tokens(next_tokens) | ||||||
|         finished_sequences = self.request_handler.update() |         finished_sequences = self.request_handler.update() | ||||||
|  |  | ||||||
|   | |||||||
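The new branch in `prepare_input` above only materializes `batch.batch_token_ids` when the generation config enables a penalty that actually needs per-sequence token histories, since gathering them is otherwise wasted work. A standalone sketch of that gating check (the helper name is hypothetical; the commit inlines this loop):

```python
from typing import Dict, List, Optional


def maybe_gather_batch_token_ids(
    config_dict: Dict[str, object],
    batch_token_ids: List[List[int]],
) -> Optional[List[List[int]]]:
    """Mirror of the check in prepare_input: only hand the per-sequence token
    histories to the sampler when an option that consumes them is configured."""
    for key in ("repetition_penalty", "no_repeat_ngram_size"):
        if key in config_dict and config_dict[key] is not None:
            return batch_token_ids
    return None


# An explicit repetition_penalty makes the histories flow through to the sampler.
assert maybe_gather_batch_token_ids({"repetition_penalty": 1.2}, [[1, 2], [3]]) == [[1, 2], [3]]
assert maybe_gather_batch_token_ids({"top_p": 0.9}, [[1, 2]]) is None
```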
| @@ -7,10 +7,11 @@ from transformers.generation import GenerationConfig | |||||||
| from colossalai.inference.batch_bucket import BatchBucket | from colossalai.inference.batch_bucket import BatchBucket | ||||||
| from colossalai.inference.config import InferenceConfig | from colossalai.inference.config import InferenceConfig | ||||||
| from colossalai.inference.flash_decoding_utils import FDIntermTensors | from colossalai.inference.flash_decoding_utils import FDIntermTensors | ||||||
| from colossalai.inference.kv_cache import KVCacheManager | from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager | ||||||
| from colossalai.inference.logit_processors import logit_processor |  | ||||||
| from colossalai.inference.sampler import * |  | ||||||
| from colossalai.inference.struct import RequestStatus, Sequence | from colossalai.inference.struct import RequestStatus, Sequence | ||||||
|  | from colossalai.logging import get_dist_logger | ||||||
|  |  | ||||||
|  | logger = get_dist_logger(__name__) | ||||||
|  |  | ||||||
| __all__ = ["RunningList", "RequestHandler"] | __all__ = ["RunningList", "RequestHandler"] | ||||||
|  |  | ||||||
| @@ -295,17 +296,6 @@ class RequestHandler: | |||||||
|  |  | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     def _sample(self, probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig): |  | ||||||
|         if generation_config.num_beams == 1: |  | ||||||
|             if generation_config.do_sample: |  | ||||||
|                 sample_tokens = multinomial_sample(generation_config, probs) |  | ||||||
|             else: |  | ||||||
|                 sample_tokens = greedy_sample(generation_config, logprobs) |  | ||||||
|         else: |  | ||||||
|             sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=not self.prefill_bb.is_empty) |  | ||||||
|  |  | ||||||
|         return sample_tokens |  | ||||||
|  |  | ||||||
|     def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig): |     def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig): | ||||||
|         if ( |         if ( | ||||||
|             sequence.output_token_id[-1] == generation_config.eos_token_id |             sequence.output_token_id[-1] == generation_config.eos_token_id | ||||||
| @@ -328,33 +318,6 @@ class RequestHandler: | |||||||
|     def total_requests_in_batch_bucket(self) -> int: |     def total_requests_in_batch_bucket(self) -> int: | ||||||
|         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size |         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size | ||||||
|  |  | ||||||
|     def search_tokens(self, generation_config: GenerationConfig, logits, cur_batch: BatchBucket): |  | ||||||
|         """ |  | ||||||
|         Sample tokens for finished requests. |  | ||||||
|         """ |  | ||||||
|  |  | ||||||
|         # NOTE: need to decide the granularity to process logits (sequence or batch) |  | ||||||
|         config_dict = generation_config.to_dict() |  | ||||||
|         # process repetition_penalty, no_repeat_ngram_size |  | ||||||
|         for type in ["repetition_penalty", "no_repeat_ngram_size"]: |  | ||||||
|             if type in config_dict and config_dict[type] is not None: |  | ||||||
|                 logits = logit_processor(type, logits, config_dict[type], cur_batch) |  | ||||||
|  |  | ||||||
|         # do logit processor |  | ||||||
|         if generation_config.do_sample: |  | ||||||
|             # process temperature, top_k, top_p |  | ||||||
|             for type in ["temperature", "top_k", "top_p"]: |  | ||||||
|                 if type in config_dict and config_dict[type] is not None: |  | ||||||
|                     logits = logit_processor(type, logits, config_dict[type]) |  | ||||||
|  |  | ||||||
|         # calculate probs |  | ||||||
|         probs = torch.softmax(logits, dim=-1, dtype=torch.float) |  | ||||||
|         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) |  | ||||||
|  |  | ||||||
|         # sample the next tokens |  | ||||||
|         sample_tokens = self._sample(probs, logprobs, generation_config) |  | ||||||
|         return sample_tokens |  | ||||||
|  |  | ||||||
|     def append_next_tokens(self, sample_tokens: torch.Tensor): |     def append_next_tokens(self, sample_tokens: torch.Tensor): | ||||||
|         assert sample_tokens.dim() == 1 |         assert sample_tokens.dim() == 1 | ||||||
|         n_elements = sample_tokens.size(0) |         n_elements = sample_tokens.size(0) | ||||||
| @@ -386,3 +349,53 @@ class RequestHandler: | |||||||
|         self.done_list.extend(finished_seqs) |         self.done_list.extend(finished_seqs) | ||||||
|  |  | ||||||
|         return finished_seqs |         return finished_seqs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RPCRequestHandler(RequestHandler): | ||||||
|  |     """ | ||||||
|  |     RPC Version of request handler | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None: | ||||||
|  |         self.inference_config = inference_config | ||||||
|  |         self.running_list: RunningList = RunningList(inference_config.prefill_ratio) | ||||||
|  |         self.waiting_list: List[List] = [[], [], []] | ||||||
|  |         self.done_list: List[Sequence] = [] | ||||||
|  |         self.dtype = inference_config.dtype | ||||||
|  |         self.max_batch_size = inference_config.max_batch_size | ||||||
|  |  | ||||||
|  |         # initialize cache | ||||||
|  |         self._init_cache(model_config) | ||||||
|  |  | ||||||
|  |         # initialize batch | ||||||
|  |         torch.cuda.current_device() | ||||||
|  |         kv_max_split_num = ( | ||||||
|  |             inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 | ||||||
|  |         ) // inference_config.block_size | ||||||
|  |         head_dim = model_config.hidden_size // model_config.num_attention_heads | ||||||
|  |  | ||||||
|  |         # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, | ||||||
|  |         # which may cause bugs and this issue should be fixed later. | ||||||
|  |         self.running_bb = BatchBucket( | ||||||
|  |             num_heads=model_config.num_attention_heads // inference_config.tp_size, | ||||||
|  |             head_dim=head_dim, | ||||||
|  |             max_batch_size=self.max_batch_size, | ||||||
|  |             max_length=inference_config.max_input_len + inference_config.max_output_len, | ||||||
|  |             block_size=inference_config.block_size, | ||||||
|  |             kv_max_split_num=kv_max_split_num, | ||||||
|  |             fd_interm_tensor=None, | ||||||
|  |             dtype=self.dtype, | ||||||
|  |         ) | ||||||
|  |         self.prefill_bb = BatchBucket( | ||||||
|  |             num_heads=model_config.num_attention_heads // inference_config.tp_size, | ||||||
|  |             head_dim=head_dim, | ||||||
|  |             max_batch_size=self.max_batch_size, | ||||||
|  |             max_length=inference_config.max_input_len + inference_config.max_output_len, | ||||||
|  |             block_size=inference_config.block_size, | ||||||
|  |             kv_max_split_num=kv_max_split_num, | ||||||
|  |             fd_interm_tensor=None, | ||||||
|  |             dtype=self.dtype, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def _init_cache(self, model_config): | ||||||
|  |         self.cache_manager = RPCKVCacheManager(self.inference_config, model_config) | ||||||
|   | |||||||
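The removed `RequestHandler.search_tokens` / `_sample` above are replaced by a free function `search_tokens` in `colossalai.inference.sampler`, which the engine and the RPC worker call with an explicit `batch_token_ids` argument instead of a `BatchBucket`. The real implementation is not shown in this diff; the sketch below is a simplified reconstruction from the removed code and the new call sites, covering only repetition penalty, temperature, and greedy/multinomial sampling (beam search omitted).

```python
from typing import List, Optional

import torch
from transformers.generation import GenerationConfig


def search_tokens_sketch(
    generation_config: GenerationConfig,
    logits: torch.Tensor,  # [batch_size, vocab_size]
    is_prompt: bool = False,  # used for beam-search bookkeeping in the real sampler; unused here
    batch_token_ids: Optional[List[List[int]]] = None,
) -> torch.Tensor:
    """Simplified stand-in for colossalai.inference.sampler.search_tokens."""
    config = generation_config.to_dict()

    # Penalties that need the per-sequence histories are the reason batch_token_ids exists.
    penalty = config.get("repetition_penalty")
    if penalty is not None and penalty != 1.0 and batch_token_ids is not None:
        for i, history in enumerate(batch_token_ids):
            ids = torch.tensor(history, dtype=torch.long, device=logits.device)
            scores = logits[i, ids]
            logits[i, ids] = torch.where(scores > 0, scores / penalty, scores * penalty)

    # Temperature scaling, then greedy or multinomial sampling.
    temperature = config.get("temperature") or 1.0
    probs = torch.softmax(logits / temperature, dim=-1, dtype=torch.float)
    if generation_config.do_sample:
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    return torch.argmax(probs, dim=-1)
```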
							
								
								
									
colossalai/inference/core/rpc_engine.py (new file, 291 lines)
							| @@ -0,0 +1,291 @@ | |||||||
|  | import asyncio | ||||||
|  | from itertools import count | ||||||
|  | from time import sleep | ||||||
|  | from typing import List, Tuple, Union | ||||||
|  |  | ||||||
|  | import rpyc | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from rpyc.utils.server import ThreadedServer | ||||||
|  | from torch import multiprocessing as mp | ||||||
|  | from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast | ||||||
|  | from transformers.configuration_utils import PretrainedConfig | ||||||
|  |  | ||||||
|  | from colossalai.inference.batch_bucket import BatchBucket | ||||||
|  | from colossalai.inference.config import InferenceConfig, InputMetaData | ||||||
|  | from colossalai.inference.executor.rpc_worker import rpcWorkerService | ||||||
|  | from colossalai.inference.utils import find_available_ports | ||||||
|  | from colossalai.logging import get_dist_logger | ||||||
|  | from colossalai.shardformer.policies.base_policy import Policy | ||||||
|  |  | ||||||
|  | from .engine import InferenceEngine | ||||||
|  | from .request_handler import RPCRequestHandler | ||||||
|  |  | ||||||
|  | __all__ = ["RPCInferenceEngine"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_server(host, port, event: mp.Event = None): | ||||||
|  |     server = ThreadedServer( | ||||||
|  |         rpcWorkerService, port=port, protocol_config={"allow_public_attrs": True, "allow_all_attrs": True} | ||||||
|  |     ) | ||||||
|  |     if event: | ||||||
|  |         event.set() | ||||||
|  |     server.start() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RPCInferenceEngine(InferenceEngine): | ||||||
|  |  | ||||||
|  |     """ | ||||||
|  |     An InferenceEngine that manages the inference process over RPC. | ||||||
|  |  | ||||||
|  |     NOTE: This `RPCInferenceEngine` is designed for multi-card/online serving. | ||||||
|  |     The original `InferenceEngine` targets single-card, offline service, though it also supports multi-card offline inference. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         model_or_path (nn.Module or str): Path to, or nn.Module of, this model. Passing an `nn.Module` is currently not supported. | ||||||
|  |         tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use. | ||||||
|  |         inference_config (Optional[InferenceConfig], optional): Stores the configuration information related to inference. | ||||||
|  |         verbose (bool): Whether or not to log the generation process. | ||||||
|  |         model_policy ("Policy"): The policy used to shard the model with shardformer. It is determined by the model type if not provided. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         model_or_path: Union[nn.Module, str], | ||||||
|  |         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], | ||||||
|  |         inference_config: InferenceConfig, | ||||||
|  |         verbose: bool = False, | ||||||
|  |         model_policy: Policy = None, | ||||||
|  |     ) -> None: | ||||||
|  |         """ | ||||||
|  |         Initializing from a model already loaded by transformers can take quite a long time. | ||||||
|  |         Passing the model as an `nn.Module` is currently not supported; pass a path instead. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         torch.multiprocessing.set_start_method("spawn", force=True) | ||||||
|  |  | ||||||
|  |         self.inference_config = inference_config | ||||||
|  |         self.tokenizer = tokenizer | ||||||
|  |         self.tokenizer.pad_token = self.tokenizer.eos_token | ||||||
|  |  | ||||||
|  |         self.verbose = verbose | ||||||
|  |         self.logger = get_dist_logger(__name__) | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             if isinstance(model_or_path, str): | ||||||
|  |                 self.model_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) | ||||||
|  |             elif isinstance(model_or_path, nn.Module): | ||||||
|  |                 self.logger.error( | ||||||
|  |                     f"Failed to load the model config: {__class__.__name__} does not currently support an nn.Module as input\n" | ||||||
|  |                 ) | ||||||
|  |                 # self.model_config = model_or_path.config | ||||||
|  |             else: | ||||||
|  |                 self.logger.error( | ||||||
|  |                     f"Failed to load the model config: please pass a valid model path to {__class__.__name__}\n" | ||||||
|  |                 ) | ||||||
|  |         except Exception as e: | ||||||
|  |             self.logger.error( | ||||||
|  |                 f"Failed to load the model config: {e}. The path should point to a transformers-style checkpoint\n" | ||||||
|  |             ) | ||||||
|  |         self.generation_config = inference_config.to_generation_config(self.model_config) | ||||||
|  |  | ||||||
|  |         self.tp_size = inference_config.tp_size | ||||||
|  |         self.events = [mp.Event() for _ in range(self.tp_size)] | ||||||
|  |  | ||||||
|  |         # This operation will init the dist env and models | ||||||
|  |         self.workers: List[rpcWorkerService] = [] | ||||||
|  |         self.init_workers() | ||||||
|  |  | ||||||
|  |         asyncio.run(self.init_model(model_or_path, model_policy)) | ||||||
|  |  | ||||||
|  |         # init the scheduler and logic block manager | ||||||
|  |         self.request_handler = self.init_scheduler(self.inference_config, self.model_config) | ||||||
|  |  | ||||||
|  |         # init the physical cache | ||||||
|  |         alloc_shape = self.request_handler.cache_manager.get_physical_cache_shape() | ||||||
|  |         self.init_device_cache(alloc_shape) | ||||||
|  |  | ||||||
|  |         self.use_cuda_graph = self.inference_config.use_cuda_graph | ||||||
|  |         self.high_precision = inference_config.high_precision | ||||||
|  |         self.dtype = inference_config.dtype | ||||||
|  |  | ||||||
|  |         # The drafter model and related attributes for speculative decoding are set by `enable_spec_dec` | ||||||
|  |         self.use_spec_dec = False | ||||||
|  |         self.drafter_model = None | ||||||
|  |         self.drafter = None | ||||||
|  |         self.use_glide = False | ||||||
|  |         self.n_spec_tokens = self.inference_config.max_n_spec_tokens | ||||||
|  |  | ||||||
|  |         self.counter = count() | ||||||
|  |         self._verify_args() | ||||||
|  |  | ||||||
|  |         self.logger.info("engine init over ") | ||||||
|  |  | ||||||
|  |     def _verify_args(self) -> None: | ||||||
|  |         """Verify the input args""" | ||||||
|  |         if not isinstance(self.inference_config, InferenceConfig): | ||||||
|  |             raise TypeError("Invalid type of inference config provided.") | ||||||
|  |         if not isinstance(self.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer)): | ||||||
|  |             raise TypeError( | ||||||
|  |                 f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def init_workers(self): | ||||||
|  |         rpc_ports = find_available_ports(self.tp_size) | ||||||
|  |         self.worker_processes = [] | ||||||
|  |         # mp.set_start_method('spawn') | ||||||
|  |         for event, rpc_port in zip(self.events, rpc_ports): | ||||||
|  |             p = mp.Process(target=run_server, args=("localhost", rpc_port, event)) | ||||||
|  |             p.start() | ||||||
|  |             self.worker_processes.append(p) | ||||||
|  |             self.logger.info(f"Starting RPC Worker on localhost:{rpc_port}...") | ||||||
|  |  | ||||||
|  |         # Wait for all servers to start | ||||||
|  |         for event in self.events: | ||||||
|  |             event.wait() | ||||||
|  |             event.clear() | ||||||
|  |  | ||||||
|  |         sleep(0.05) | ||||||
|  |  | ||||||
|  |         self.logger.info(f"init rpc server done.") | ||||||
|  |  | ||||||
|  |         for rpc_port in rpc_ports: | ||||||
|  |             try: | ||||||
|  |                 conn = rpyc.connect( | ||||||
|  |                     "localhost", | ||||||
|  |                     rpc_port, | ||||||
|  |                     config={"allow_pickle": True, "allow_public_attrs": True, "allow_all_attrs": True}, | ||||||
|  |                 ) | ||||||
|  |                 self.workers.append(conn.root) | ||||||
|  |             except Exception as e: | ||||||
|  |                 raise Exception(f"Failed to connect to the RPC worker on localhost:{rpc_port}") from e | ||||||
|  |         self.logger.info(f"Build RPC Connection Success! Begin to load model...") | ||||||
|  |         asyncio.run(self.init_worker_env()) | ||||||
|  |         self.logger.info(f"init dist env over") | ||||||
|  |  | ||||||
|  |     async def async_parallel_wrapper(self, f, *args, **kwargs): | ||||||
|  |         async_res = rpyc.async_(f)(*args, **kwargs) | ||||||
|  |         await asyncio.to_thread(async_res.wait) | ||||||
|  |         assert async_res.ready | ||||||
|  |         return async_res.value | ||||||
|  |  | ||||||
|  |     async def init_worker_env(self): | ||||||
|  |         assert len(self.workers) == self.tp_size, "init workers first" | ||||||
|  |  | ||||||
|  |         dist_group_port = find_available_ports(1)[0] | ||||||
|  |         init_tasks = [ | ||||||
|  |             self.async_parallel_wrapper( | ||||||
|  |                 worker.init_dist_env, rank, self.inference_config.tp_size, "127.0.0.1", dist_group_port | ||||||
|  |             ) | ||||||
|  |             for rank, worker in enumerate(self.workers) | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         await asyncio.gather(*init_tasks) | ||||||
|  |  | ||||||
|  |     async def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): | ||||||
|  |         assert len(self.workers) == self.tp_size, "init workers first" | ||||||
|  |  | ||||||
|  |         inference_config_param = self.inference_config.to_rpc_param() | ||||||
|  |         model_path = model_or_path | ||||||
|  |         model_policy_param = model_policy.to_rpc_param() if model_policy else None | ||||||
|  |  | ||||||
|  |         init_tasks = [ | ||||||
|  |             self.async_parallel_wrapper(worker.init_model, inference_config_param, model_path, model_policy_param) | ||||||
|  |             for rank, worker in enumerate(self.workers) | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         await asyncio.gather(*init_tasks) | ||||||
|  |  | ||||||
|  |     def init_scheduler(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> RPCRequestHandler: | ||||||
|  |         return RPCRequestHandler(inference_config, model_config) | ||||||
|  |  | ||||||
|  |     async def _init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): | ||||||
|  |         assert len(self.workers) == self.tp_size, "init workers first" | ||||||
|  |  | ||||||
|  |         init_tasks = [self.async_parallel_wrapper(worker.init_cache, alloc_shape) for worker in self.workers] | ||||||
|  |  | ||||||
|  |         await asyncio.gather(*init_tasks) | ||||||
|  |  | ||||||
|  |     def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): | ||||||
|  |         asyncio.run(self._init_device_cache(alloc_shape)) | ||||||
|  |  | ||||||
|  |     def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: | ||||||
|  |         input_ids = batch.get_1D_inputs() | ||||||
|  |         sequence_lengths = batch.get_sequence_lengths() | ||||||
|  |  | ||||||
|  |         if batch.is_prompts: | ||||||
|  |             n_tokens = sequence_lengths.sum().item() | ||||||
|  |         else: | ||||||
|  |             n_tokens = batch.current_batch_size | ||||||
|  |             if batch.use_spec_dec: | ||||||
|  |                 n_tokens = batch.num_tokens_to_verify + 1 | ||||||
|  |                 assert n_tokens == input_ids.size(0) | ||||||
|  |                 n_tokens = n_tokens * batch.current_batch_size | ||||||
|  |  | ||||||
|  |         batch_token_ids = None | ||||||
|  |         config_dict = self.generation_config.to_dict() | ||||||
|  |         # process repetition_penalty, no_repeat_ngram_size | ||||||
|  |         for type in ["repetition_penalty", "no_repeat_ngram_size"]: | ||||||
|  |             if type in config_dict and config_dict[type] is not None: | ||||||
|  |                 batch_token_ids = batch.batch_token_ids | ||||||
|  |  | ||||||
|  |         # only when we have the graph for specific decoding batch size can we use the cuda graph for inference | ||||||
|  |         use_cuda_graph = False | ||||||
|  |         if self.use_cuda_graph and not batch.is_prompts and batch.current_batch_size in self.graph_runners.keys(): | ||||||
|  |             use_cuda_graph = True | ||||||
|  |  | ||||||
|  |         input_meta_data = InputMetaData( | ||||||
|  |             block_tables=batch.get_block_table_tensor(), | ||||||
|  |             sequence_lengths=sequence_lengths, | ||||||
|  |             fd_inter_tensor=None, | ||||||
|  |             batch_size=batch.current_batch_size, | ||||||
|  |             is_prompts=batch.is_prompts, | ||||||
|  |             use_cuda_kernel=self.inference_config.use_cuda_kernel, | ||||||
|  |             use_cuda_graph=use_cuda_graph, | ||||||
|  |             high_precision=self.high_precision, | ||||||
|  |             kv_seq_len=sequence_lengths.max().item(), | ||||||
|  |             head_dim=batch.head_dim, | ||||||
|  |             dtype=batch.dtype, | ||||||
|  |             use_spec_dec=batch.use_spec_dec, | ||||||
|  |             num_tokens_to_verify=batch.num_tokens_to_verify, | ||||||
|  |             batch_token_ids=batch_token_ids, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         return input_ids.tolist(), input_meta_data | ||||||
|  |  | ||||||
|  |     async def step_(self, input_token_ids, input_meta_data: InputMetaData): | ||||||
|  |         assert len(self.workers) == self.tp_size, "init workers first" | ||||||
|  |  | ||||||
|  |         init_tasks = [ | ||||||
|  |             self.async_parallel_wrapper(worker.execute_model_forward, input_token_ids, input_meta_data.to_rpc_param()) | ||||||
|  |             for worker in self.workers | ||||||
|  |         ] | ||||||
|  |         ret = await asyncio.gather(*init_tasks) | ||||||
|  |  | ||||||
|  |         return ret[0] | ||||||
|  |  | ||||||
|  |     def step(self) -> List[str]: | ||||||
|  |         batch = self.request_handler.schedule() | ||||||
|  |  | ||||||
|  |         input_token_ids, input_meta_data = self.prepare_input(batch) | ||||||
|  |         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. | ||||||
|  |         next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data)) | ||||||
|  |  | ||||||
|  |         # update the request_handler | ||||||
|  |         next_tokens = torch.tensor(next_tokens, dtype=torch.int) | ||||||
|  |         self.request_handler.append_next_tokens(next_tokens) | ||||||
|  |         finished_sequences = self.request_handler.update() | ||||||
|  |         return finished_sequences | ||||||
|  |  | ||||||
|  |     def kill_workers(self): | ||||||
|  |         """ | ||||||
|  |         There is no clean way to invoke `kill_workers` implicitly, so it is called explicitly from `__del__`. | ||||||
|  |         """ | ||||||
|  |         assert len(self.workers) != 0 | ||||||
|  |         for proc in self.worker_processes: | ||||||
|  |             proc.kill() | ||||||
|  |             proc.join() | ||||||
|  |         self.logger.info("all workers killed, serving ended") | ||||||
|  |  | ||||||
|  |     def __del__(self): | ||||||
|  |         self.kill_workers() | ||||||
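To make the control flow concrete, below is a self-contained sketch (not from the commit) of the same pattern `RPCInferenceEngine` uses: one rpyc `ThreadedServer` per worker, `rpyc.async_` to issue non-blocking calls, and `asyncio.gather` to fan a request out to every worker, as in `async_parallel_wrapper` and `init_worker_env` above. Assumptions: rpyc is installed and ports 18861/18862 are free; the commit launches workers with `mp.Process`, while daemon threads are used here only to keep the demo short.

```python
import asyncio
import threading
from time import sleep

import rpyc
from rpyc.utils.server import ThreadedServer


class EchoWorker(rpyc.Service):
    """Stand-in for rpcWorkerService: only `exposed_` methods are callable remotely."""

    def exposed_double(self, x: int) -> int:
        return 2 * x


def start_worker(port: int) -> ThreadedServer:
    # The commit uses one mp.Process per worker; a daemon thread keeps this demo compact.
    server = ThreadedServer(EchoWorker, port=port, protocol_config={"allow_public_attrs": True})
    threading.Thread(target=server.start, daemon=True).start()
    return server


async def call_async(remote_fn, *args):
    # Same pattern as RPCInferenceEngine.async_parallel_wrapper:
    # fire the rpyc call, then wait for completion off the event loop.
    async_res = rpyc.async_(remote_fn)(*args)
    await asyncio.to_thread(async_res.wait)
    return async_res.value


async def main() -> None:
    ports = [18861, 18862]  # assumed free ports for the demo
    servers = [start_worker(p) for p in ports]
    sleep(0.2)              # give the servers a moment to bind

    conns = [rpyc.connect("localhost", p) for p in ports]
    # Fan the same request out to every worker and gather the replies.
    results = await asyncio.gather(*(call_async(c.root.double, 21) for c in conns))
    print(results)          # [42, 42]

    for c in conns:
        c.close()
    for s in servers:
        s.close()


if __name__ == "__main__":
    asyncio.run(main())
```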
							
								
								
									
colossalai/inference/executor/rpc_worker.py (new file, 300 lines)
							| @@ -0,0 +1,300 @@ | |||||||
|  | import os | ||||||
|  | from typing import List, Tuple, Union | ||||||
|  |  | ||||||
|  | import rpyc | ||||||
|  | import torch | ||||||
|  | import torch.distributed as dist | ||||||
|  | from torch import nn | ||||||
|  | from transformers import AutoConfig, AutoModelForCausalLM | ||||||
|  | from transformers.models.llama.modeling_llama import LlamaForCausalLM | ||||||
|  |  | ||||||
|  | import colossalai | ||||||
|  | from colossalai.accelerator import get_accelerator | ||||||
|  | from colossalai.cluster import ProcessGroupMesh | ||||||
|  | from colossalai.inference.config import InferenceConfig, InputMetaData | ||||||
|  | from colossalai.inference.flash_decoding_utils import FDIntermTensors | ||||||
|  | from colossalai.inference.modeling.policy import ( | ||||||
|  |     NoPaddingBaichuanModelInferPolicy, | ||||||
|  |     NoPaddingLlamaModelInferPolicy, | ||||||
|  |     model_policy_map, | ||||||
|  | ) | ||||||
|  | from colossalai.inference.sampler import search_tokens | ||||||
|  | from colossalai.inference.utils import get_model_size, has_index_file | ||||||
|  | from colossalai.interface import ModelWrapper | ||||||
|  | from colossalai.logging import get_dist_logger | ||||||
|  | from colossalai.pipeline.stage_manager import PipelineStageManager | ||||||
|  | from colossalai.shardformer import ShardConfig, ShardFormer | ||||||
|  | from colossalai.shardformer.policies.base_policy import Policy | ||||||
|  |  | ||||||
|  | PP_AXIS, TP_AXIS = 0, 1 | ||||||
|  |  | ||||||
|  | _SUPPORTED_MODELS = { | ||||||
|  |     "LlamaForCausalLM": LlamaForCausalLM, | ||||||
|  |     "BaichuanForCausalLM": AutoModelForCausalLM, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | _SUPPORTED_MODEL_POLICIES = { | ||||||
|  |     "NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy, | ||||||
|  |     "NoPaddingBaichuanModelInferPolicy": NoPaddingBaichuanModelInferPolicy, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | logger = get_dist_logger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class rpcWorkerService(rpyc.Service): | ||||||
|  |  | ||||||
|  |     """ | ||||||
|  |     Executes the computation tasks and manages its own KV cache. | ||||||
|  |  | ||||||
|  |     Methods prefixed with `exposed_` can be invoked by the client. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def exposed_init_dist_env(self, rank, world_size, master_address, master_port): | ||||||
|  |         logger.info(f"init process group for rank {rank}") | ||||||
|  |         colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address) | ||||||
|  |         logger.info(f"init process group done for rank {rank}") | ||||||
|  |  | ||||||
|  |     def exposed_init_model( | ||||||
|  |         self, inference_config_param: dict, model_or_path: Union[nn.Module, str], model_policy_param: str = None | ||||||
|  |     ): | ||||||
|  |         assert dist.is_initialized(), "invoke init_dist_env first please!" | ||||||
|  |  | ||||||
|  |         self.inference_config = InferenceConfig.from_rpc_param(inference_config_param) | ||||||
|  |         model_policy = _SUPPORTED_MODEL_POLICIES[model_policy_param]() if model_policy_param else None | ||||||
|  |  | ||||||
|  |         self.dtype = self.inference_config.dtype | ||||||
|  |         self.verbose = True | ||||||
|  |  | ||||||
|  |         self._init_model(model_or_path, model_policy) | ||||||
|  |         self._init_fd_tensor() | ||||||
|  |         self._init_output_tensor() | ||||||
|  |         logger.info(f"init model done for rank {dist.get_rank()}") | ||||||
|  |  | ||||||
|  |     def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): | ||||||
|  |         """Initialize the physical cache on the device. | ||||||
|  |  | ||||||
|  |         For each layer of the model, we allocate two tensors for key and value respectively, | ||||||
|  |         with shape of [num_blocks, num_kv_heads, block_size, head_size] | ||||||
|  |         """ | ||||||
|  |         kalloc_shape, valloc_shape = alloc_shape | ||||||
|  |         num_layers = self.model_config.num_hidden_layers | ||||||
|  |  | ||||||
|  |         self.k_cache: List[torch.Tensor] = [] | ||||||
|  |         self.v_cache: List[torch.Tensor] = [] | ||||||
|  |         for _ in range(num_layers): | ||||||
|  |             self.k_cache.append( | ||||||
|  |                 torch.zeros( | ||||||
|  |                     kalloc_shape, | ||||||
|  |                     dtype=self.inference_config.kv_cache_dtype, | ||||||
|  |                     device=get_accelerator().get_current_device(), | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |             self.v_cache.append( | ||||||
|  |                 torch.zeros( | ||||||
|  |                     valloc_shape, | ||||||
|  |                     dtype=self.inference_config.kv_cache_dtype, | ||||||
|  |                     device=get_accelerator().get_current_device(), | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         logger.info("physical cache init over") | ||||||
|  |  | ||||||
|  |     def exposed_execute_model_forward(self, input_token_ids_param: List[int], input_meta_data_param: dict): | ||||||
|  |         # prepare the data for model forward | ||||||
|  |         input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) | ||||||
|  |         input_meta_data.fd_inter_tensor = self.fd_inter_tensor | ||||||
|  |         if input_meta_data.is_prompts: | ||||||
|  |             n_tokens = input_meta_data.sequence_lengths.sum().item() | ||||||
|  |         else: | ||||||
|  |             n_tokens = input_meta_data.batch_size | ||||||
|  |         input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) | ||||||
|  |  | ||||||
|  |         # execute the model | ||||||
|  |         logits = self.model( | ||||||
|  |             input_token_ids, | ||||||
|  |             self.output_tensor[:n_tokens], | ||||||
|  |             input_meta_data, | ||||||
|  |             self.k_cache, | ||||||
|  |             self.v_cache, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # sampler | ||||||
|  |         if self.inference_config.pad_input: | ||||||
|  |             logits = logits[:, -1, :] | ||||||
|  |         next_tokens = search_tokens( | ||||||
|  |             self.inference_config.to_generation_config(self.model_config), | ||||||
|  |             logits, | ||||||
|  |             input_meta_data.is_prompts, | ||||||
|  |             input_meta_data.batch_token_ids, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # return the tokens generated to scheduler | ||||||
|  |         return next_tokens.tolist() | ||||||
|  |  | ||||||
|  |     def _init_output_tensor(self): | ||||||
|  |         alloc_shape = ( | ||||||
|  |             self.inference_config.max_batch_size | ||||||
|  |             * (self.inference_config.max_input_len + self.inference_config.max_output_len), | ||||||
|  |             self.model_config.hidden_size // self.inference_config.tp_size, | ||||||
|  |         ) | ||||||
|  |         self.output_tensor = torch.zeros(alloc_shape, dtype=self.dtype, device=self.device) | ||||||
|  |  | ||||||
|  |     def _init_fd_tensor(self): | ||||||
|  |         fd_inter_tensor = FDIntermTensors() | ||||||
|  |  | ||||||
|  |         if fd_inter_tensor._tensors_initialized: | ||||||
|  |             fd_inter_tensor._reset() | ||||||
|  |  | ||||||
|  |         # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq | ||||||
|  |         max_n_tokens = self.inference_config.max_batch_size | ||||||
|  |         max_n_tokens *= self.inference_config.max_n_spec_tokens + 1 | ||||||
|  |  | ||||||
|  |         inference_config = self.inference_config | ||||||
|  |         kv_max_split_num = ( | ||||||
|  |             inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 | ||||||
|  |         ) // inference_config.block_size | ||||||
|  |         head_dim = self.model_config.hidden_size // self.model_config.num_attention_heads | ||||||
|  |  | ||||||
|  |         fd_inter_tensor.initialize( | ||||||
|  |             max_batch_size=max_n_tokens, | ||||||
|  |             num_attn_heads=self.model_config.num_attention_heads // self.inference_config.tp_size, | ||||||
|  |             kv_max_split_num=kv_max_split_num, | ||||||
|  |             head_dim=head_dim, | ||||||
|  |             dtype=self.dtype, | ||||||
|  |             device=get_accelerator().get_current_device(), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.fd_inter_tensor = fd_inter_tensor | ||||||
|  |  | ||||||
|  |     def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): | ||||||
|  |         """ | ||||||
|  |         Shard the model and/or load its weights. | ||||||
|  |  | ||||||
|  |         Shard model: when tp_size > 1, the model is sharded according to the given model_policy. | ||||||
|  |         Load weights: if a local model path is passed, the weights are loaded via checkpoint_io; if it is a remote transformers repo id, `AutoModel.from_pretrained` from the transformers library is used. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             model_or_path (Union[nn.Module, str]): path to the checkpoint, or a model in transformers format. | ||||||
|  |             model_policy (Policy): the policy used to shard the model. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         if isinstance(model_or_path, str): | ||||||
|  |             is_local = os.path.isdir(model_or_path) | ||||||
|  |             try: | ||||||
|  |                 hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True) | ||||||
|  |                 arch = getattr(hf_config, "architectures")[0] | ||||||
|  |                 if is_local: | ||||||
|  |                     model = _SUPPORTED_MODELS[arch](hf_config) | ||||||
|  |                 else: | ||||||
|  |                     # load the real checkpoint | ||||||
|  |                     model = _SUPPORTED_MODELS[arch].from_pretrained(model_or_path, trust_remote_code=True) | ||||||
|  |             except Exception as e: | ||||||
|  |                 logger.error( | ||||||
|  |                     f"Failed to load the model: {e}. The model should be loadable by transformers\n" | ||||||
|  |                 ) | ||||||
|  |         else: | ||||||
|  |             model = model_or_path | ||||||
|  |  | ||||||
|  |         self.model_config = model.config | ||||||
|  |  | ||||||
|  |         torch.cuda.empty_cache() | ||||||
|  |         init_gpu_memory = torch.cuda.mem_get_info()[0] | ||||||
|  |  | ||||||
|  |         self.device = get_accelerator().get_current_device() | ||||||
|  |         torch.cuda.set_device(self.device) | ||||||
|  |         if self.verbose: | ||||||
|  |             logger.info(f"the device is {self.device}") | ||||||
|  |  | ||||||
|  |         model = model.to(dtype=self.dtype, non_blocking=False).eval() | ||||||
|  |  | ||||||
|  |         if self.verbose: | ||||||
|  |             logger.info( | ||||||
|  |                 f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if model_policy is None: | ||||||
|  |             if self.inference_config.pad_input: | ||||||
|  |                 model_type = "padding_" + self.model_config.model_type | ||||||
|  |             else: | ||||||
|  |                 model_type = "nopadding_" + self.model_config.model_type | ||||||
|  |             model_policy = model_policy_map[model_type]() | ||||||
|  |  | ||||||
|  |         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) | ||||||
|  |         tp_group = pg_mesh.get_group_along_axis(TP_AXIS) | ||||||
|  |  | ||||||
|  |         self.model = self._shardformer( | ||||||
|  |             model, | ||||||
|  |             model_policy, | ||||||
|  |             None, | ||||||
|  |             tp_group=tp_group, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         self.model = ModelWrapper(model).to(device=get_accelerator().get_current_device()) | ||||||
|  |  | ||||||
|  |         if self.verbose: | ||||||
|  |             logger.info( | ||||||
|  |                 f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if isinstance(model_or_path, str) and is_local: | ||||||
|  |             from colossalai.inference.core.plugin import InferCheckpoint_io | ||||||
|  |  | ||||||
|  |             cpt_io = InferCheckpoint_io() | ||||||
|  |             if_has_index_file, model_index_file = has_index_file(model_or_path) | ||||||
|  |             assert if_has_index_file, "the model path is invalid" | ||||||
|  |             cpt_io.load_model(self.model, model_index_file) | ||||||
|  |  | ||||||
|  |         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() | ||||||
|  |         peak_memory = init_gpu_memory - free_gpu_memory | ||||||
|  |         if self.verbose: | ||||||
|  |             logger.info( | ||||||
|  |                 f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def _shardformer( | ||||||
|  |         self, | ||||||
|  |         model: nn.Module, | ||||||
|  |         model_policy: Policy, | ||||||
|  |         stage_manager: PipelineStageManager = None, | ||||||
|  |         tp_group: ProcessGroupMesh = None, | ||||||
|  |     ) -> nn.Module: | ||||||
|  |         """ | ||||||
|  |         Initialize ShardConfig and replace the model with shardformer. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             model (nn.Module): Path or nn.Module of this model. | ||||||
|  |             model_policy (Policy): The policy to shardformer model which is determined by the model type. | ||||||
|  |             stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None. | ||||||
|  |             tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None. | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             nn.Module: The model optimized by Shardformer. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         shardconfig = ShardConfig( | ||||||
|  |             tensor_parallel_process_group=tp_group, | ||||||
|  |             pipeline_stage_manager=stage_manager, | ||||||
|  |             enable_tensor_parallelism=(self.inference_config.tp_size > 1), | ||||||
|  |             enable_fused_normalization=False, | ||||||
|  |             enable_all_optimization=False, | ||||||
|  |             enable_flash_attention=False, | ||||||
|  |             enable_jit_fused=False, | ||||||
|  |             enable_sequence_parallelism=False, | ||||||
|  |         ) | ||||||
|  |         shardformer = ShardFormer(shard_config=shardconfig) | ||||||
|  |         shard_model, _ = shardformer.optimize(model, model_policy) | ||||||
|  |         return shard_model | ||||||
|  |  | ||||||
|  |     def exposed_compute_only_for_test(self): | ||||||
|  |         dist_rank = dist.get_rank() | ||||||
|  |  | ||||||
|  |         # Dummy data for each worker | ||||||
|  |         data = torch.tensor([dist_rank], dtype=torch.float).cuda(dist_rank) | ||||||
|  |         dist.barrier() | ||||||
|  |  | ||||||
|  |         # Perform distributed all_reduce | ||||||
|  |         dist.all_reduce(data, op=dist.ReduceOp.SUM) | ||||||
|  |  | ||||||
|  |         dist.barrier() | ||||||
|  |         logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}") | ||||||
|  |  | ||||||
|  |         return data.item() | ||||||
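The `exposed_compute_only_for_test` method follows rpyc's convention that any method whose name starts with `exposed_` becomes remotely callable through a connection's `root` proxy, with the prefix stripped. As a hedged sketch only — the host, port, and service wiring below are assumptions for illustration, not part of this PR — a client-side sanity check could look like:

import rpyc

# Assumed host/port; in practice they depend on how the RPC workers were launched.
conn = rpyc.connect("localhost", 18861)
# rpyc strips the `exposed_` prefix, so the method is reachable as `compute_only_for_test`.
result = conn.root.compute_only_for_test()
print(f"all_reduce sum reported by the worker: {result}")
conn.close()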
| @@ -1,4 +1,4 @@ | |||||||
| from .block_cache import CacheBlock | from .block_cache import CacheBlock | ||||||
| from .kvcache_manager import KVCacheManager | from .kvcache_manager import KVCacheManager, RPCKVCacheManager | ||||||
|  |  | ||||||
| __all__ = ["CacheBlock", "KVCacheManager"] | __all__ = ["CacheBlock", "KVCacheManager", "RPCKVCacheManager"] | ||||||
|   | |||||||
| @@ -497,3 +497,80 @@ class KVCacheManager: | |||||||
|             k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device)) |             k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device)) | ||||||
|             v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device)) |             v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device)) | ||||||
|         return k_cache, v_cache |         return k_cache, v_cache | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RPCKVCacheManager(KVCacheManager): | ||||||
|  |     def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None: | ||||||
|  |         self.logger = get_dist_logger(__name__) | ||||||
|  |         self.device = get_current_device() | ||||||
|  |         self.config = config | ||||||
|  |  | ||||||
|  |         # Parallel settings | ||||||
|  |         self.tp_size = config.tp_size | ||||||
|  |         # Model settings | ||||||
|  |         self.dtype = config.dtype | ||||||
|  |         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size() | ||||||
|  |         self.num_layers = model_config.num_hidden_layers | ||||||
|  |         self.head_num = model_config.num_attention_heads | ||||||
|  |         self.head_size = model_config.hidden_size // self.head_num | ||||||
|  |         if hasattr(model_config, "num_key_value_heads"): | ||||||
|  |             self.kv_head_num = model_config.num_key_value_heads | ||||||
|  |         else: | ||||||
|  |             self.kv_head_num = self.head_num | ||||||
|  |  | ||||||
|  |         if config.kv_cache_dtype is None: | ||||||
|  |             self.kv_cache_dtype = config.dtype | ||||||
|  |         else: | ||||||
|  |             self.kv_cache_dtype = config.kv_cache_dtype | ||||||
|  |  | ||||||
|  |         assert ( | ||||||
|  |             self.kv_head_num % self.tp_size == 0 | ||||||
|  |         ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}" | ||||||
|  |         self.kv_head_num //= self.tp_size | ||||||
|  |         self.beam_width = config.beam_width | ||||||
|  |         self.max_batch_size = config.max_batch_size | ||||||
|  |         self.max_input_length = config.max_input_len | ||||||
|  |         self.max_output_length = config.max_output_len | ||||||
|  |         # Cache block settings | ||||||
|  |         self.block_size = config.block_size | ||||||
|  |         # NOTE: `num_blocks` is not configured directly; it is derived from the maximum input/output lengths and the maximum batch size | ||||||
|  |         self.max_blocks_per_sequence = ( | ||||||
|  |             self.max_input_length + self.max_output_length + self.block_size - 1 | ||||||
|  |         ) // self.block_size | ||||||
|  |         self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width | ||||||
|  |  | ||||||
|  |         # Logical cache blocks allocation | ||||||
|  |         self._available_blocks = self.num_blocks | ||||||
|  |         self._cache_blocks = tuple(self._init_logical_caches()) | ||||||
|  |         # block availability state: 0 -> allocated, 1 -> free | ||||||
|  |         self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool) | ||||||
|  |         self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64) | ||||||
|  |         self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64) | ||||||
|  |  | ||||||
|  |     def get_physical_cache_shape(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: | ||||||
|  |         # Physical cache allocation | ||||||
|  |         if self.config.use_cuda_kernel: | ||||||
|  |             x = 16 // torch.tensor([], dtype=self.config.dtype).element_size() | ||||||
|  |             kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x) | ||||||
|  |             valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) | ||||||
|  |             self.logger.info( | ||||||
|  |                 f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks." | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size) | ||||||
|  |             kalloc_shape = alloc_shape | ||||||
|  |             valloc_shape = alloc_shape | ||||||
|  |             self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.") | ||||||
|  |         return kalloc_shape, valloc_shape | ||||||
|  |  | ||||||
|  |     def get_kv_cache(self): | ||||||
|  |         """Get k_cache and v_cache. Not implemented here: the physical caches live in the RPC workers, not in this client-side manager.""" | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def _init_logical_caches(self): | ||||||
|  |         """Initialize the logical cache blocks.""" | ||||||
|  |         blocks = [] | ||||||
|  |         for i in range(self.num_blocks): | ||||||
|  |             cache_block = CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs=None, v_ptrs=None) | ||||||
|  |             blocks.append(cache_block) | ||||||
|  |         return blocks | ||||||
|   | |||||||
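The class above is the logical half of the KV-cache disaggregation: `RPCKVCacheManager` keeps only block bookkeeping on the client and never allocates GPU memory, while the shapes from `get_physical_cache_shape()` are meant to be consumed by the worker processes that own the GPUs. A rough sketch of that hand-off, with hypothetical stand-in values (the real worker-side allocation lives elsewhere in this PR):

import torch

# Hypothetical values standing in for what the client sends over rpyc:
# (num_blocks, kv_head_num, block_size, head_size) when CUDA kernels are disabled.
kalloc_shape = valloc_shape = (512, 8, 16, 128)
kv_dtype, device, num_layers = torch.float16, "cuda", 32

# Worker side (sketch): one K and one V tensor per layer, mirroring
# KVCacheManager's physical allocation shown earlier in this diff.
k_cache = [torch.zeros(kalloc_shape, dtype=kv_dtype, device=device) for _ in range(num_layers)]
v_cache = [torch.zeros(valloc_shape, dtype=kv_dtype, device=device) for _ in range(num_layers)]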
| @@ -1,10 +1,9 @@ | |||||||
| # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py | # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py | ||||||
|  | from typing import List | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
|  |  | ||||||
| from colossalai.inference.batch_bucket import BatchBucket |  | ||||||
|  |  | ||||||
| _LOGIT_PROCESSOR_MAP = {} | _LOGIT_PROCESSOR_MAP = {} | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -22,7 +21,7 @@ def register_logit_processor(process_type): | |||||||
|  |  | ||||||
|  |  | ||||||
| @register_logit_processor("no_repeat_ngram_size") | @register_logit_processor("no_repeat_ngram_size") | ||||||
| def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket): | def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch_token_ids: List[List[int]]): | ||||||
|     """ |     """ | ||||||
|     enforces no repetition of n-grams to avoid repetitions of word sequences. |     enforces no repetition of n-grams to avoid repetitions of word sequences. | ||||||
|     """ |     """ | ||||||
| @@ -31,7 +30,6 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBuck | |||||||
|         raise ValueError(f"'ngram_size={ngram_size}' should be a strictly positive integer.") |         raise ValueError(f"'ngram_size={ngram_size}' should be a strictly positive integer.") | ||||||
|  |  | ||||||
|     if ngram_size != 0: |     if ngram_size != 0: | ||||||
|         batch_token_ids = batch.batch_token_ids |  | ||||||
|         batch_size = len(batch_token_ids) |         batch_size = len(batch_token_ids) | ||||||
|  |  | ||||||
|         for batch_id in range(batch_size): |         for batch_id in range(batch_size): | ||||||
| @@ -55,7 +53,7 @@ def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBuck | |||||||
|  |  | ||||||
|  |  | ||||||
| @register_logit_processor("repetition_penalty") | @register_logit_processor("repetition_penalty") | ||||||
| def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket): | def repetition_penalty_logit_process(logits, penalty: float, batch_token_ids: List[List[int]]): | ||||||
|     """ |     """ | ||||||
|     apply the penalty to the tokens present in the prompt. |     apply the penalty to the tokens present in the prompt. | ||||||
|     """ |     """ | ||||||
| @@ -67,7 +65,6 @@ def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket) | |||||||
|  |  | ||||||
|     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels. |     # TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels. | ||||||
|     if penalty != 1.0: |     if penalty != 1.0: | ||||||
|         batch_token_ids = batch.batch_token_ids |  | ||||||
|         for batch_id in range(len(batch_token_ids)): |         for batch_id in range(len(batch_token_ids)): | ||||||
|             current_logit = logits[batch_id] |             current_logit = logits[batch_id] | ||||||
|             current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device) |             current_token = torch.tensor(batch_token_ids[batch_id], dtype=torch.long, device=logits.device) | ||||||
|   | |||||||
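After this change the penalty-style processors receive a plain `List[List[int]]` of per-sequence token ids instead of a `BatchBucket` object, which is the kind of POD payload rpyc can ship. A small usage sketch of the registry dispatcher — its call shape matches the `logit_processor(...)` calls in `search_tokens` further down in this diff; the tensors and ids here are made up:

import torch

from colossalai.inference.logit_processors import logit_processor

# Made-up batch: 2 sequences over a vocabulary of 8 tokens.
logits = torch.randn(2, 8)
batch_token_ids = [[1, 2, 3, 2], [4, 4, 5, 6]]

# Penalize tokens already present in each sequence, then forbid repeated 3-grams.
logits = logit_processor("repetition_penalty", logits, 1.2, batch_token_ids)
logits = logit_processor("no_repeat_ngram_size", logits, 3, batch_token_ids)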
| @@ -1,3 +1,4 @@ | |||||||
|  | from colossalai.inference.config import RPC_PARAM | ||||||
| from colossalai.inference.modeling.layers.baichuan_tp_linear import ( | from colossalai.inference.modeling.layers.baichuan_tp_linear import ( | ||||||
|     BaichuanLMHeadLinear1D_Col, |     BaichuanLMHeadLinear1D_Col, | ||||||
|     BaichuanWpackLinear1D_Col, |     BaichuanWpackLinear1D_Col, | ||||||
| @@ -18,7 +19,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, | |||||||
| from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy | from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy | ||||||
|  |  | ||||||
|  |  | ||||||
| class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy): | class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): | ||||||
|     def __init__(self) -> None: |     def __init__(self) -> None: | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |  | ||||||
| @@ -100,3 +101,10 @@ class NoPaddingBaichuanModelInferPolicy(LlamaForCausalLMPolicy): | |||||||
|     def postprocess(self): |     def postprocess(self): | ||||||
|         init_to_get_rotary(self.model.model) |         init_to_get_rotary(self.model.model) | ||||||
|         return self.model |         return self.model | ||||||
|  |  | ||||||
|  |     def to_rpc_param(self) -> str: | ||||||
|  |         return __class__.__name__ | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def from_rpc_param() -> "NoPaddingBaichuanModelInferPolicy": | ||||||
|  |         return NoPaddingBaichuanModelInferPolicy() | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm | from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm | ||||||
|  |  | ||||||
|  | from colossalai.inference.config import RPC_PARAM | ||||||
| from colossalai.inference.modeling.models.nopadding_llama import ( | from colossalai.inference.modeling.models.nopadding_llama import ( | ||||||
|     NopadLlamaAttention, |     NopadLlamaAttention, | ||||||
|     NopadLlamaMLP, |     NopadLlamaMLP, | ||||||
| @@ -14,7 +15,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, | |||||||
| from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy | from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy | ||||||
|  |  | ||||||
|  |  | ||||||
| class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): | class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy, RPC_PARAM): | ||||||
|     def __init__(self) -> None: |     def __init__(self) -> None: | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |  | ||||||
| @@ -102,3 +103,10 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): | |||||||
|     def postprocess(self): |     def postprocess(self): | ||||||
|         init_to_get_rotary(self.model.model, self.model.config.rope_theta) |         init_to_get_rotary(self.model.model, self.model.config.rope_theta) | ||||||
|         return self.model |         return self.model | ||||||
|  |  | ||||||
|  |     def to_rpc_param(self) -> str: | ||||||
|  |         return __class__.__name__ | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def from_rpc_param() -> "NoPaddingLlamaModelInferPolicy": | ||||||
|  |         return NoPaddingLlamaModelInferPolicy() | ||||||
|   | |||||||
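Both inference policies therefore serialize to nothing richer than their class name: rpyc cannot move the policy objects themselves, so the client sends a string and the server rebuilds an equivalent instance. A hedged sketch of the round trip — the `_POLICY_BY_NAME` lookup table here is an assumption for illustration; the actual name-to-class mapping lives in the RPC server code of this PR:

from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy

# Client side: reduce the policy object to a plain string before the rpyc call.
policy_param = NoPaddingLlamaModelInferPolicy().to_rpc_param()  # -> "NoPaddingLlamaModelInferPolicy"

# Server side (sketch): map the received name back to a class and rebuild the policy.
_POLICY_BY_NAME = {"NoPaddingLlamaModelInferPolicy": NoPaddingLlamaModelInferPolicy}
policy = _POLICY_BY_NAME[policy_param].from_rpc_param()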
| @@ -1,6 +1,9 @@ | |||||||
| from typing import List, Tuple | from typing import List, Optional, Tuple | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
|  | from transformers.generation import GenerationConfig | ||||||
|  |  | ||||||
|  | from colossalai.inference.logit_processors import logit_processor | ||||||
|  |  | ||||||
|  |  | ||||||
| def greedy_sample( | def greedy_sample( | ||||||
| @@ -59,3 +62,47 @@ def beam_search_sample( | |||||||
|  |  | ||||||
|     results.append((next_token_ids, parent_ids)) |     results.append((next_token_ids, parent_ids)) | ||||||
|     return results |     return results | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _sample(probs: torch.Tensor, logprobs: torch.Tensor, generation_config: GenerationConfig, is_prompt: bool = False): | ||||||
|  |     if generation_config.num_beams == 1: | ||||||
|  |         if generation_config.do_sample: | ||||||
|  |             sample_tokens = multinomial_sample(generation_config, probs) | ||||||
|  |         else: | ||||||
|  |             sample_tokens = greedy_sample(generation_config, logprobs) | ||||||
|  |     else: | ||||||
|  |         sample_tokens = beam_search_sample(generation_config, logprobs, is_prompt=is_prompt) | ||||||
|  |  | ||||||
|  |     return sample_tokens | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def search_tokens( | ||||||
|  |     generation_config: GenerationConfig, | ||||||
|  |     logits, | ||||||
|  |     is_prompt: bool = False, | ||||||
|  |     batch_token_ids: Optional[List[List[int]]] = None, | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     Sample the next tokens for the current batch after applying the configured logit processors. | ||||||
|  |     """ | ||||||
|  |     # NOTE: need to decide the granularity to process logits (sequence or batch) | ||||||
|  |     config_dict = generation_config.to_dict() | ||||||
|  |     # process repetition_penalty, no_repeat_ngram_size | ||||||
|  |     for type in ["repetition_penalty", "no_repeat_ngram_size"]: | ||||||
|  |         if type in config_dict and config_dict[type] is not None: | ||||||
|  |             logits = logit_processor(type, logits, config_dict[type], batch_token_ids) | ||||||
|  |  | ||||||
|  |     # apply the sampling-related logit processors only when sampling is enabled | ||||||
|  |     if generation_config.do_sample: | ||||||
|  |         # process temperature, top_k, top_p | ||||||
|  |         for type in ["temperature", "top_k", "top_p"]: | ||||||
|  |             if type in config_dict and config_dict[type] is not None: | ||||||
|  |                 logits = logit_processor(type, logits, config_dict[type]) | ||||||
|  |  | ||||||
|  |     # calculate probs | ||||||
|  |     probs = torch.softmax(logits, dim=-1, dtype=torch.float) | ||||||
|  |     logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) | ||||||
|  |  | ||||||
|  |     # sample the next tokens | ||||||
|  |     sample_tokens = _sample(probs, logprobs, generation_config, is_prompt) | ||||||
|  |     return sample_tokens | ||||||
|   | |||||||
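`search_tokens` is the per-step entry point: it first applies `repetition_penalty`/`no_repeat_ngram_size` using `batch_token_ids`, then the sampling processors when `do_sample` is set, and finally draws the next token ids. A minimal, made-up invocation for illustration (assuming `search_tokens` has been imported from this sampler module):

import torch
from transformers.generation import GenerationConfig

generation_config = GenerationConfig(
    do_sample=True, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.2
)
logits = torch.randn(2, 32000)            # one row of raw logits per sequence in the batch
batch_token_ids = [[1, 2, 3], [4, 5, 6]]  # prompt + generated ids so far, per sequence

next_token_ids = search_tokens(
    generation_config, logits, is_prompt=False, batch_token_ids=batch_token_ids
)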
| @@ -9,6 +9,8 @@ from typing import Optional, Tuple | |||||||
| import torch | import torch | ||||||
| from torch import nn | from torch import nn | ||||||
|  |  | ||||||
|  | from colossalai.testing import free_port | ||||||
|  |  | ||||||
|  |  | ||||||
| def init_to_get_rotary(self, base=10000, use_elem=False): | def init_to_get_rotary(self, base=10000, use_elem=False): | ||||||
|     """ |     """ | ||||||
| @@ -102,3 +104,12 @@ def get_model_size(model: nn.Module): | |||||||
|     for key, param in model.named_parameters(): |     for key, param in model.named_parameters(): | ||||||
|         total_size += param.element_size() * param.numel() |         total_size += param.element_size() * param.numel() | ||||||
|     return total_size / (1024**3) |     return total_size / (1024**3) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def find_available_ports(num: int): | ||||||
|  |     try: | ||||||
|  |         free_ports = [free_port() for i in range(num)] | ||||||
|  |     except OSError as e: | ||||||
|  |         print(f"An OS error occurred: {e}") | ||||||
|  |         raise RuntimeError("Error finding available ports") | ||||||
|  |     return free_ports | ||||||
|   | |||||||
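`find_available_ports` simply asks `free_port()` for `num` distinct free ports. A plausible use — an assumption here, not shown in this hunk — is reserving one rpyc port per tensor-parallel worker before launching them:

# Assumes the helper is importable from the inference utils module shown above.
from colossalai.inference.utils import find_available_ports

tp_size = 2  # illustrative: one rpyc worker process per tensor-parallel rank
ports = find_available_ports(tp_size)
print(f"launching {tp_size} rpyc workers on ports {ports}")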
| @@ -19,4 +19,5 @@ datasets | |||||||
| pydantic | pydantic | ||||||
| ray | ray | ||||||
| peft>=0.7.1 | peft>=0.7.1 | ||||||
|  | rpyc==6.0.0 | ||||||
| #auto-gptq now not support torch1.12 | #auto-gptq now not support torch1.12 | ||||||
|   | |||||||
| @@ -19,3 +19,4 @@ protobuf | |||||||
| transformers==4.36.2 | transformers==4.36.2 | ||||||
| peft>=0.7.1 | peft>=0.7.1 | ||||||
| bitsandbytes>=0.39.0 | bitsandbytes>=0.39.0 | ||||||
|  | rpyc==6.0.0 | ||||||
|   | |||||||
							
								
								
									
tests/test_infer/test_rpc_engine.py (new file, 105 lines)
							| @@ -0,0 +1,105 @@ | |||||||
|  | import random | ||||||
|  |  | ||||||
|  | import numpy as np | ||||||
|  | import pytest | ||||||
|  | import torch | ||||||
|  | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig | ||||||
|  |  | ||||||
|  | from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig | ||||||
|  | from colossalai.inference.core.rpc_engine import RPCInferenceEngine | ||||||
|  | from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy | ||||||
|  | from colossalai.testing import parameterize, rerun_if_address_is_in_use | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def setup_seed(seed): | ||||||
|  |     torch.manual_seed(seed) | ||||||
|  |     torch.random.manual_seed(seed) | ||||||
|  |     torch.cuda.manual_seed_all(seed) | ||||||
|  |     np.random.seed(seed) | ||||||
|  |     random.seed(seed) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_inference_engine(tp_size, use_engine=False, prompt_template=None, do_sample=True, policy=None): | ||||||
|  |     setup_seed(20) | ||||||
|  |     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") | ||||||
|  |     model = "meta-llama/Llama-2-7b-hf"  # remote model path | ||||||
|  |     inputs = [ | ||||||
|  |         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", | ||||||
|  |         "介绍一下武汉,", | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     output_len = 38 | ||||||
|  |     top_p = 0.5 | ||||||
|  |     top_k = 50 | ||||||
|  |  | ||||||
|  |     if use_engine: | ||||||
|  |         inference_config = InferenceConfig( | ||||||
|  |             max_output_len=output_len, | ||||||
|  |             prompt_template=prompt_template, | ||||||
|  |             dtype="fp32", | ||||||
|  |             use_cuda_kernel=True, | ||||||
|  |             tp_size=tp_size, | ||||||
|  |         ) | ||||||
|  |         inference_engine = RPCInferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy) | ||||||
|  |         assert inference_engine.generation_config.max_new_tokens == output_len | ||||||
|  |         inference_engine.add_request(prompts=inputs) | ||||||
|  |         assert inference_engine.request_handler._has_waiting() | ||||||
|  |         generation_config = GenerationConfig( | ||||||
|  |             max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k | ||||||
|  |         ) | ||||||
|  |         outputs = inference_engine.generate(generation_config=generation_config) | ||||||
|  |     else: | ||||||
|  |         if prompt_template: | ||||||
|  |             # apply prompt template | ||||||
|  |             inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs] | ||||||
|  |         model = AutoModelForCausalLM.from_pretrained(model).cuda() | ||||||
|  |         tokenizer.pad_token = tokenizer.eos_token | ||||||
|  |         tokenizer.pad_token_id = tokenizer.eos_token_id | ||||||
|  |         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"] | ||||||
|  |         inputs = inputs.cuda() | ||||||
|  |         generation_config = GenerationConfig( | ||||||
|  |             do_sample=do_sample, | ||||||
|  |             dtype="fp32", | ||||||
|  |             top_p=top_p, | ||||||
|  |             top_k=top_k, | ||||||
|  |             pad_token_id=tokenizer.pad_token_id, | ||||||
|  |             max_new_tokens=output_len, | ||||||
|  |         ) | ||||||
|  |         outputs = model.generate(inputs, generation_config=generation_config) | ||||||
|  |         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) | ||||||
|  |  | ||||||
|  |     return outputs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_engine(tp_size, **kwargs): | ||||||
|  |     return check_inference_engine(tp_size=tp_size, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.largedist | ||||||
|  | @parameterize("prompt_template", [None, "llama"]) | ||||||
|  | @parameterize("do_sample", [False]) | ||||||
|  | @rerun_if_address_is_in_use() | ||||||
|  | def test_tp_engine(prompt_template, do_sample): | ||||||
|  |     if torch.multiprocessing.get_start_method(allow_none=True) is None: | ||||||
|  |         torch.multiprocessing.set_start_method("spawn") | ||||||
|  |     kwargs1 = { | ||||||
|  |         "use_engine": True, | ||||||
|  |         "prompt_template": prompt_template, | ||||||
|  |         "do_sample": do_sample, | ||||||
|  |         "policy": NoPaddingLlamaModelInferPolicy(), | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None} | ||||||
|  |  | ||||||
|  |     colossal_tp_1_output = run_engine(1, **kwargs1) | ||||||
|  |     colossal_tp_2_output = run_engine(2, **kwargs1) | ||||||
|  |     transformer_tp_1_output = run_engine(1, **kwargs2) | ||||||
|  |  | ||||||
|  |     for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output): | ||||||
|  |         assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}" | ||||||
|  |         assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     torch.multiprocessing.set_start_method("spawn")  # the spawn start method is required; forking subprocesses does not work with this setup | ||||||
|  |     test_tp_engine() | ||||||