Fix Sagemaker Batch Endpoints (#3249)

Add different typing for @evandiewald's helpful PR

---------

Co-authored-by: Evan Diewald <evandiewald@gmail.com>
Commit 61d40ba042 (parent 7e79f8c136)
Author: Zander Chase
Date: 2023-04-22 08:49:51 -07:00
Committed by: GitHub
3 changed files with 48 additions and 29 deletions


@@ -9,7 +9,15 @@
 "\n",
 "Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
 "\n",
-"For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
+"For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
+"\n",
+"Change from\n",
+"\n",
+"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
+"\n",
+"to:\n",
+"\n",
+"`return {\"vectors\": sentence_embeddings.tolist()}`."
 ]
 },
 {
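For reference, a minimal sketch of what the adjusted `predict_fn()` in the custom `inference.py` might look like, assuming the mean-pooling sentence-embedding setup from the linked philschmid guide; the `mean_pooling` helper and the `(model, tokenizer)` pair returned by `model_fn` are assumptions, not part of this diff:

```python
import torch
import torch.nn.functional as F


def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, weighted by the attention mask.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


def predict_fn(data, model_and_tokenizer):
    # model_and_tokenizer is whatever model_fn returned (assumed: a (model, tokenizer) pair).
    model, tokenizer = model_and_tokenizer
    # "inputs" may now be a single string or a list of strings.
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**encoded_input)
    sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    # Return every vector, not just sentence_embeddings[0], so batched requests
    # get one embedding back per input string.
    return {"vectors": sentence_embeddings.tolist()}
```

The only change relevant to this PR is the final return line, which keeps the full batch of vectors instead of only the first one.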
@@ -29,7 +37,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from typing import Dict\n",
+"from typing import Dict, List\n",
 "from langchain.embeddings import SagemakerEndpointEmbeddings\n",
 "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
 "import json\n",
@@ -39,13 +47,13 @@
 "    content_type = \"application/json\"\n",
 "    accepts = \"application/json\"\n",
 "\n",
-"    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
-"        input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
+"    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
+"        input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
 "        return input_str.encode('utf-8')\n",
-"    \n",
-"    def transform_output(self, output: bytes) -> str:\n",
+"\n",
+"    def transform_output(self, output: bytes) -> List[List[float]]:\n",
 "        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
-"        return response_json[\"embeddings\"]\n",
+"        return response_json[\"vectors\"]\n",
 "\n",
 "content_handler = ContentHandler()\n",
 "\n",