Mirror of https://github.com/hpcaitech/ColossalAI.git
[Feat]Tensor Model Parallel Support For Inference (#5563)
* tensor parallel support naive source
* [fix] precision, model load and refactor the framework
* add tp unit test
* docstring
* fix do_sample
@@ -140,7 +140,7 @@ class RequestHandler:
         fd_inter_tensor.initialize(
             max_batch_size=max_n_tokens,
-            num_attn_heads=model_config.num_attention_heads,
+            num_attn_heads=model_config.num_attention_heads // inference_config.tp_size,
             kv_max_split_num=kv_max_split_num,
             head_dim=head_dim,
             dtype=self.dtype,
@@ -150,7 +150,7 @@ class RequestHandler:
         # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
         # which may cause bugs and this issue should be fixed later.
         self.running_bb = BatchBucket(
-            num_heads=model_config.num_attention_heads,
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
             max_length=inference_config.max_input_len + inference_config.max_output_len,
@@ -161,7 +161,7 @@ class RequestHandler:
             device=device,
         )
         self.prefill_bb = BatchBucket(
-            num_heads=model_config.num_attention_heads,
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
             max_length=inference_config.max_input_len + inference_config.max_output_len,
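The single change repeated in all three hunks is the core of the tensor-parallel support: each TP rank holds only `num_attention_heads // tp_size` heads, so its KV-cache and intermediate buffers are sized for that shard rather than for the full model. A minimal sketch of that arithmetic is below; `ModelConfig`/`InferenceConfig` here are simplified stand-ins mirroring the field names in the diff, and `local_num_heads` is an illustrative helper, not a ColossalAI API.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    num_attention_heads: int  # global head count, before tensor-parallel sharding


@dataclass
class InferenceConfig:
    tp_size: int  # tensor-parallel degree (number of TP ranks)


def local_num_heads(model_config: ModelConfig, inference_config: InferenceConfig) -> int:
    """Attention heads owned by one TP rank.

    Each rank allocates KV-cache and batch buffers only for its own shard,
    which is why the diff divides num_attention_heads by tp_size.
    """
    assert model_config.num_attention_heads % inference_config.tp_size == 0, (
        "head count must be divisible by the tensor-parallel degree"
    )
    return model_config.num_attention_heads // inference_config.tp_size


# Example: a 32-head model split across 4 TP ranks keeps 8 heads per rank.
assert local_num_heads(ModelConfig(32), InferenceConfig(4)) == 8
```

With `tp_size=1` the division is a no-op, so the non-parallel path is unchanged; the divisibility check matters because an uneven split would leave ranks with mismatched buffer shapes.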