[Feat] Tensor Model Parallel Support For Inference (#5563)

* naive tensor parallel support (initial source)

* [fix] precision, model loading, and framework refactor

* add tp unit test

* docstring

* fix do_sample
Author: Runyu Lu
Date: 2024-04-18 16:56:46 +08:00
Committed by: GitHub
Parent: be396ad6cc
Commit: e37ee2fb65
8 changed files with 640 additions and 150 deletions


@@ -40,7 +40,7 @@ def check_inference_engine(use_cuda_graph=False, batch_size=32):
     input_len = 1024
     output_len = 128
-    do_sample = True
+    do_sample = False
     top_p = 0.5
     top_k = 50
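The hunk above disables sampling in the test harness. A plausible rationale, sketched here as an assumption (the helper `next_token` below is illustrative, not part of the PR): with `do_sample=True`, the next token is drawn stochastically, so a tensor-parallel run and a single-GPU reference can produce different outputs even when their logits agree, making an equality check flaky. Greedy decoding (`do_sample=False`) is a pure argmax and is reproducible.

```python
import random

def next_token(logits, do_sample, top_k=50, rng=None):
    """Pick a next-token id from raw (positive) scores.

    Greedy (do_sample=False): deterministic argmax.
    Sampling (do_sample=True): stochastic draw from the top-k
    candidates, so repeated runs may diverge.
    """
    if not do_sample:
        # Greedy decoding: always the highest-scoring token.
        return max(range(len(logits)), key=logits.__getitem__)
    rng = rng or random.Random()
    top = sorted(range(len(logits)), key=logits.__getitem__, reverse=True)[:top_k]
    weights = [logits[i] for i in top]  # random.choices normalizes these
    return rng.choices(top, weights=weights, k=1)[0]

scores = [0.1, 2.5, 0.3, 1.7]
# Greedy decoding yields the same token on every run.
greedy_runs = [next_token(scores, do_sample=False) for _ in range(5)]
assert greedy_runs == [1] * 5
```

Under this assumption, comparing engine outputs token-for-token against a baseline only makes sense with `do_sample=False`; `top_p` and `top_k` then have no effect on the result.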