mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-10-31 16:08:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			128 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			128 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| {
 | |
|  "cells": [
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "1f83f273",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "# SageMaker Endpoint Embeddings\n",
 | |
|     "\n",
 | |
|     "Let's load the SageMaker Endpoint Embeddings class. The class can be used if you host your own model (e.g., a Hugging Face model) on a SageMaker endpoint.\n",
 | |
|     "\n",
 | |
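|     "For instructions on how to do this, see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker).\n",
 | |
|     "\n",
 | |
|     "The `ContentHandler` defined below serializes the request sent to the endpoint and parses the JSON response it returns; adjust it to match the input and output format of your model's inference code."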
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "88d366bd",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "!pip3 install langchain boto3"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 3,
 | |
|    "id": "1e9b926a",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "from typing import Dict\n",
 | |
|     "from langchain.embeddings import SagemakerEndpointEmbeddings\n",
 | |
|     "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
 | |
|     "import json\n",
 | |
|     "\n",
 | |
|     "\n",
 | |
|     "class ContentHandler(ContentHandlerBase):\n",
 | |
|     "    content_type = \"application/json\"\n",
 | |
|     "    accepts = \"application/json\"\n",
 | |
|     "\n",
 | |
|     "    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
 | |
|     "        input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
 | |
|     "        return input_str.encode('utf-8')\n",
 | |
|     "    \n",
 | |
|     "    def transform_output(self, output: bytes) -> str:\n",
 | |
|     "        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
 | |
|     "        return response_json[\"embeddings\"]\n",
 | |
|     "\n",
 | |
|     "content_handler = ContentHandler()\n",
 | |
|     "\n",
 | |
|     "\n",
 | |
|     "embeddings = SagemakerEndpointEmbeddings(\n",
 | |
|     "    # endpoint_name=\"endpoint-name\", \n",
 | |
|     "    # credentials_profile_name=\"credentials-profile-name\", \n",
 | |
|     "    endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\", \n",
 | |
|     "    region_name=\"us-east-1\", \n",
 | |
|     "    content_handler=content_handler\n",
 | |
|     ")"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "fe9797b8",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "query_result = embeddings.embed_query(\"foo\")"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 6,
 | |
|    "id": "76f1b752",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "doc_results = embeddings.embed_documents([\"foo\"])"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "fff99b21",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "doc_results"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "aaad49f8",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": []
 | |
|   }
 | |
|  ],
 | |
|  "metadata": {
 | |
|   "kernelspec": {
 | |
|    "display_name": "Python 3 (ipykernel)",
 | |
|    "language": "python",
 | |
|    "name": "python3"
 | |
|   },
 | |
|   "language_info": {
 | |
|    "codemirror_mode": {
 | |
|     "name": "ipython",
 | |
|     "version": 3
 | |
|    },
 | |
|    "file_extension": ".py",
 | |
|    "mimetype": "text/x-python",
 | |
|    "name": "python",
 | |
|    "nbconvert_exporter": "python",
 | |
|    "pygments_lexer": "ipython3",
 | |
|    "version": "3.9.1"
 | |
|   },
 | |
|   "vscode": {
 | |
|    "interpreter": {
 | |
|     "hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
 | |
|    }
 | |
|   }
 | |
|  },
 | |
|  "nbformat": 4,
 | |
|  "nbformat_minor": 5
 | |
| }
 |