mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-06 05:25:04 +00:00
Fix Sagemaker Batch Endpoints (#3249)
Add different typing for @evandiewald 's heplful PR --------- Co-authored-by: Evan Diewald <evandiewald@gmail.com>
This commit is contained in:
@@ -9,7 +9,15 @@
|
||||
"\n",
|
||||
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker.\n",
|
||||
"\n",
|
||||
"For instrucstions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
|
||||
"For instructions on how to do this, please see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: In order to handle batched requests, you will need to adjust the return line in the `predict_fn()` function within the custom `inference.py` script:\n",
|
||||
"\n",
|
||||
"Change from\n",
|
||||
"\n",
|
||||
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
|
||||
"\n",
|
||||
"to:\n",
|
||||
"\n",
|
||||
"`return {\"vectors\": sentence_embeddings.tolist()}`."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -29,7 +37,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Dict\n",
|
||||
"from typing import Dict, List\n",
|
||||
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
||||
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
||||
"import json\n",
|
||||
@@ -39,13 +47,13 @@
|
||||
" content_type = \"application/json\"\n",
|
||||
" accepts = \"application/json\"\n",
|
||||
"\n",
|
||||
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
||||
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
|
||||
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
|
||||
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
|
||||
" return input_str.encode('utf-8')\n",
|
||||
" \n",
|
||||
" def transform_output(self, output: bytes) -> str:\n",
|
||||
"\n",
|
||||
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
|
||||
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||
" return response_json[\"embeddings\"]\n",
|
||||
" return response_json[\"vectors\"]\n",
|
||||
"\n",
|
||||
"content_handler = ContentHandler()\n",
|
||||
"\n",
|
||||
|
Reference in New Issue
Block a user