[Feat] Tensor Model Parallel Support For Inference (#5563)

* add naive tensor parallel support

* [fix] precision and model loading; refactor the framework

* add tp unit test

* docstring

* fix do_sample
Runyu Lu
2024-04-18 16:56:46 +08:00
committed by GitHub
parent be396ad6cc
commit e37ee2fb65
8 changed files with 640 additions and 150 deletions


@@ -140,7 +140,7 @@ class RequestHandler:
         fd_inter_tensor.initialize(
             max_batch_size=max_n_tokens,
-            num_attn_heads=model_config.num_attention_heads,
+            num_attn_heads=model_config.num_attention_heads // inference_config.tp_size,
             kv_max_split_num=kv_max_split_num,
             head_dim=head_dim,
             dtype=self.dtype,
@@ -150,7 +150,7 @@ class RequestHandler:
         # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
         # which may cause bugs and this issue should be fixed later.
         self.running_bb = BatchBucket(
-            num_heads=model_config.num_attention_heads,
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
             max_length=inference_config.max_input_len + inference_config.max_output_len,
@@ -161,7 +161,7 @@ class RequestHandler:
             device=device,
         )
         self.prefill_bb = BatchBucket(
-            num_heads=model_config.num_attention_heads,
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
             max_length=inference_config.max_input_len + inference_config.max_output_len,
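
All three hunks apply the same sizing rule: once the attention layers are sharded for tensor parallelism, each rank only owns num_attention_heads // tp_size heads, so the KV-cache intermediate tensors and batch buckets must be allocated per rank rather than for the full model. Below is a minimal sketch of that rule using a hypothetical heads_per_rank helper (not part of this PR or the ColossalAI API), just to make the arithmetic explicit:

# Hypothetical helper, not from ColossalAI: computes the per-rank head count
# that the diff above passes to fd_inter_tensor.initialize() and BatchBucket.
def heads_per_rank(num_attention_heads: int, tp_size: int) -> int:
    # Tensor parallelism shards whole heads across ranks, so the total head
    # count must divide evenly by the tensor-parallel world size.
    if num_attention_heads % tp_size != 0:
        raise ValueError("num_attention_heads must be divisible by tp_size")
    return num_attention_heads // tp_size


# Example: a 32-head model sharded across 4 tensor-parallel GPUs keeps 8 heads per rank.
assert heads_per_rank(32, 4) == 8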