mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Docs: enrich SageMaker endpoint embeddings with docstrings and examples (#9924)
Description: added comments to address the relationship between input/output transformations and the customised inference.py script.
This commit is contained in:
parent
8dbf4cbe80
commit
0fb95ebe66
@ -48,10 +48,31 @@
|
|||||||
" accepts = \"application/json\"\n",
|
" accepts = \"application/json\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
|
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
|
||||||
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
|
" \"\"\"\n",
|
||||||
|
" Transforms the input into bytes that can be consumed by SageMaker endpoint.\n",
|
||||||
|
" Args:\n",
|
||||||
|
" inputs: List of input strings.\n",
|
||||||
|
" model_kwargs: Additional keyword arguments to be passed to the endpoint.\n",
|
||||||
|
" Returns:\n",
|
||||||
|
" The transformed bytes input.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Example: inference.py expects a JSON string with a \"inputs\" key:\n",
|
||||||
|
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs}) \n",
|
||||||
" return input_str.encode(\"utf-8\")\n",
|
" return input_str.encode(\"utf-8\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
|
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Transforms the bytes output from the endpoint into a list of embeddings.\n",
|
||||||
|
" Args:\n",
|
||||||
|
" output: The bytes output from SageMaker endpoint.\n",
|
||||||
|
" Returns:\n",
|
||||||
|
" The transformed output - list of embeddings\n",
|
||||||
|
" Note:\n",
|
||||||
|
" The length of the outer list is the number of input strings.\n",
|
||||||
|
" The length of the inner lists is the embedding dimension.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Example: inference.py returns a JSON string with the list of\n",
|
||||||
|
" # embeddings in a \"vectors\" key:\n",
|
||||||
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||||
" return response_json[\"vectors\"]\n",
|
" return response_json[\"vectors\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -60,7 +81,6 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"embeddings = SagemakerEndpointEmbeddings(\n",
|
"embeddings = SagemakerEndpointEmbeddings(\n",
|
||||||
" # endpoint_name=\"endpoint-name\",\n",
|
|
||||||
" # credentials_profile_name=\"credentials-profile-name\",\n",
|
" # credentials_profile_name=\"credentials-profile-name\",\n",
|
||||||
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
|
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
|
||||||
" region_name=\"us-east-1\",\n",
|
" region_name=\"us-east-1\",\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user