mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-21 23:17:48 +00:00)
# What does this PR do?

This PR adds a SageMaker-powered `embeddings` class, similar to what already exists for `llms`. This is helpful if you want to leverage Hugging Face models hosted on SageMaker for creating your indexes. I added an example to the [docs/modules/indexes/examples/embeddings.ipynb](https://github.com/hwchase17/langchain/compare/master...philschmid:add-sm-embeddings?expand=1#diff-e82629e2894974ec87856aedd769d4bdfe400314b03734f32bee5990bc7e8062) document. The example currently includes some `_### TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_` code showing how you can deploy a sentence-transformers model to SageMaker and then run the methods of the embeddings class. @hwchase17 please let me know if/when I should remove the `_### TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_` section; the description links to a detailed blog post on how to deploy a Sentence Transformers model, so I don't think we need to include those steps here. I also reused the `ContentHandlerBase` from `langchain.llms.sagemaker_endpoint` and changed the output type to `Any`, since it depends on the implementation.
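For reference, here is a minimal usage sketch of the new class, condensed from the notebook example below; the endpoint name and region are placeholders for an already-deployed SageMaker endpoint:

```python
from typing import Dict, List
import json

from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase


class ContentHandler(ContentHandlerBase):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Serialize the text into the JSON payload the endpoint expects
        return json.dumps({"inputs": prompt, **model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> List[float]:
        # The example inference script returns {"embeddings": [...]}
        return json.loads(output.read().decode("utf-8"))["embeddings"]


embeddings = SagemakerEndpointEmbeddings(
    endpoint_name="<your-endpoint-name>",  # placeholder: name of your deployed endpoint
    region_name="us-east-1",               # placeholder: region of the endpoint
    content_handler=ContentHandler(),
)
query_result = embeddings.embed_query("This is a test document.")
```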
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "249b4058",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Embeddings\n",
|
|
"\n",
|
|
"This notebook goes over how to use the Embedding class in LangChain.\n",
|
|
"\n",
|
|
"The Embedding class is a class designed for interfacing with embeddings. There are lots of Embedding providers (OpenAI, Cohere, Hugging Face, etc) - this class is designed to provide a standard interface for all of them.\n",
|
|
"\n",
|
|
"Embeddings create a vector representation of a piece of text. This is useful because it means we can think about text in the vector space, and do things like semantic search where we look for pieces of text that are most similar in the vector space.\n",
|
|
"\n",
|
|
"The base Embedding class in LangChain exposes two methods: `embed_documents` and `embed_query`. The largest difference is that these two methods have different interfaces: one works over multiple documents, while the other works over a single document. Besides this, another reason for having these as two separate methods is that some embedding providers have different embedding methods for documents (to be searched over) vs queries (the search query itself)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "278b6c63",
|
|
"metadata": {},
|
|
"source": [
|
|
"## OpenAI\n",
|
|
"\n",
|
|
"Let's load the OpenAI Embedding class."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "0be1af71",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings import OpenAIEmbeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "2c66e5da",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = OpenAIEmbeddings()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "01370375",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"This is a test document.\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "bfb6142c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "0356c3b7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_result = embeddings.embed_documents([text])"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "bb61bbeb",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's load the OpenAI Embedding class with first generation models (e.g. text-search-ada-doc-001/text-search-ada-query-001). Note: These are not recommended models - see [here](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c0b072cc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings.openai import OpenAIEmbeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a56b70f5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = OpenAIEmbeddings(model_name=\"ada\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "14aefb64",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"This is a test document.\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3c39ed33",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e3221db6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_result = embeddings.embed_documents([text])"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "c3852491",
|
|
"metadata": {},
|
|
"source": [
|
|
"## AzureOpenAI\n",
|
|
"\n",
|
|
"Let's load the OpenAI Embedding class with environment variables set to indicate to use Azure endpoints."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1b40f827",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# set the environment variables needed for openai package to know to reach out to azure\n",
|
|
"import os\n",
|
|
"\n",
|
|
"os.environ[\"OPENAI_API_TYPE\"] = \"azure\"\n",
|
|
"os.environ[\"OPENAI_API_BASE\"] = \"https://<your-endpoint.openai.azure.com/\"\n",
|
|
"os.environ[\"OPENAI_API_KEY\"] = \"your AzureOpenAI key\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "bb36d16c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = OpenAIEmbeddings(model=\"your-embeddings-deployment-name\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "228abcbb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"This is a test document.\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "60dd7fad",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "83bc1a72",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_result = embeddings.embed_documents([text])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "42f76e43",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Cohere\n",
|
|
"\n",
|
|
"Let's load the Cohere Embedding class."
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "ca9e2b3a",
|
|
"metadata": {},
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "6b82f59f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings import CohereEmbeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "26895c60",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = CohereEmbeddings(cohere_api_key=cohere_api_key)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "eea52814",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"This is a test document.\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "fbe167bf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "38ad3b20",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_result = embeddings.embed_documents([text])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ed47bb62",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Hugging Face Hub\n",
|
|
"Let's load the Hugging Face Embedding class."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "861521a9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings import HuggingFaceEmbeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "ff9be586",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = HuggingFaceEmbeddings()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "d0a98ae9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"This is a test document.\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "5d6c682b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "bb5e74c0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_result = embeddings.embed_documents([text])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "fff4734f",
|
|
"metadata": {},
|
|
"source": [
|
|
"## TensorflowHub\n",
|
|
"Let's load the TensorflowHub Embedding class."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "f822104b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings import TensorflowHubEmbeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "bac84e46",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2023-01-30 23:53:01.652176: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
|
|
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
|
"2023-01-30 23:53:34.362802: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
|
|
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"embeddings = TensorflowHubEmbeddings()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "4790d770",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"This is a test document.\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "f556dcdb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "59428e05",
|
|
"metadata": {},
|
|
"source": [
|
|
"## InstructEmbeddings\n",
|
|
"Let's load the HuggingFace instruct Embeddings class."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "92c5b61e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings import HuggingFaceInstructEmbeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "062547b9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"load INSTRUCTOR_Transformer\n",
|
|
"max_seq_length 512\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"embeddings = HuggingFaceInstructEmbeddings(\n",
|
|
" query_instruction=\"Represent the query for retrieval: \"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "e1dcc4bd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"This is a test document.\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "90f0db94",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "eec4efda",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Self Hosted Embeddings\n",
|
|
"Let's load the SelfHostedEmbeddings, SelfHostedHuggingFaceEmbeddings, and SelfHostedHuggingFaceInstructEmbeddings classes."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d338722a",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings import (\n",
|
|
" SelfHostedEmbeddings,\n",
|
|
" SelfHostedHuggingFaceEmbeddings,\n",
|
|
" SelfHostedHuggingFaceInstructEmbeddings,\n",
|
|
")\n",
|
|
"import runhouse as rh"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "146559e8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# For an on-demand A100 with GCP, Azure, or Lambda\n",
|
|
"gpu = rh.cluster(name=\"rh-a10x\", instance_type=\"A100:1\", use_spot=False)\n",
|
|
"\n",
|
|
"# For an on-demand A10G with AWS (no single A100s on AWS)\n",
|
|
"# gpu = rh.cluster(name='rh-a10x', instance_type='g5.2xlarge', provider='aws')\n",
|
|
"\n",
|
|
"# For an existing cluster\n",
|
|
"# gpu = rh.cluster(ips=['<ip of the cluster>'],\n",
|
|
"# ssh_creds={'ssh_user': '...', 'ssh_private_key':'<path_to_key>'},\n",
|
|
"# name='my-cluster')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1230f7df",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = SelfHostedHuggingFaceEmbeddings(hardware=gpu)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "2684e928",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"text = \"This is a test document.\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1dc5e606",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "cef9cc54",
|
|
"metadata": {},
|
|
"source": [
|
|
"And similarly for SelfHostedHuggingFaceInstructEmbeddings:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "81a17ca3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = SelfHostedHuggingFaceInstructEmbeddings(hardware=gpu)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5a33d1c8",
|
|
"metadata": {},
|
|
"source": [
|
|
"Now let's load an embedding model with a custom load function:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "c4af5679",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_pipeline():\n",
|
|
" from transformers import (\n",
|
|
" AutoModelForCausalLM,\n",
|
|
" AutoTokenizer,\n",
|
|
" pipeline,\n",
|
|
" ) # Must be inside the function in notebooks\n",
|
|
"\n",
|
|
" model_id = \"facebook/bart-base\"\n",
|
|
" tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
|
|
" model = AutoModelForCausalLM.from_pretrained(model_id)\n",
|
|
" return pipeline(\"feature-extraction\", model=model, tokenizer=tokenizer)\n",
|
|
"\n",
|
|
"\n",
|
|
"def inference_fn(pipeline, prompt):\n",
|
|
" # Return last hidden state of the model\n",
|
|
" if isinstance(prompt, list):\n",
|
|
" return [emb[0][-1] for emb in pipeline(prompt)]\n",
|
|
" return pipeline(prompt)[0][-1]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8654334b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = SelfHostedEmbeddings(\n",
|
|
" model_load_fn=get_pipeline,\n",
|
|
" hardware=gpu,\n",
|
|
" model_reqs=[\"./\", \"torch\", \"transformers\"],\n",
|
|
" inference_fn=inference_fn,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fc1bfd0f",
|
|
"metadata": {
|
|
"scrolled": false
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(text)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f9c02c78",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Fake Embeddings\n",
|
|
"\n",
|
|
"LangChain also provides a fake embedding class. You can use this to test your pipelines."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "2ffc2e4b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from langchain.embeddings import FakeEmbeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "80777571",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"embeddings = FakeEmbeddings(size=1352)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "3ec9d8f0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(\"foo\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "3b9ae9e1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_results = embeddings.embed_documents([\"foo\"])"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "1f83f273",
|
|
"metadata": {},
|
|
"source": [
|
|
"## SageMaker Endpoint Embeddings\n",
|
|
"\n",
|
|
"Let's load the SageMaker Endpoints Embeddings class. The class can be used if you host, e.g. your own Hugging Face model on SageMaker learn more [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "88d366bd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip3 install langchain boto3"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "c5855922",
|
|
"metadata": {},
|
|
"source": [
|
|
"## _### TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_ "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "e0ddd9b4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install sagemaker --quiet\n",
|
|
"\n",
|
|
"import os \n",
|
|
"os.environ[\"AWS_DEFAULT_REGION\"] = \"us-east-1\"\n",
|
|
"import boto3\n",
|
|
"from sagemaker import Session\n",
|
|
"# get sagemaker execution role to deploy\n",
|
|
"iam = boto3.client('iam')\n",
|
|
"role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
|
|
"sess = Session()\n",
|
|
"# create code/ dir\n",
|
|
"os.makedirs(\"model/code\", exist_ok=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "86ce76c6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Writing model/code/inference.py\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"%%writefile model/code/inference.py\n",
|
|
"\n",
|
|
"from transformers import AutoTokenizer, AutoModel\n",
|
|
"import torch\n",
|
|
"import torch.nn.functional as F\n",
|
|
"\n",
|
|
"# Helper: Mean Pooling - Take attention mask into account for correct averaging\n",
|
|
"def mean_pooling(model_output, attention_mask):\n",
|
|
" token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n",
|
|
" input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
|
|
" return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
|
|
"\n",
|
|
"\n",
|
|
"def model_fn(model_dir):\n",
|
|
" # Load model from HuggingFace Hub\n",
|
|
" tokenizer = AutoTokenizer.from_pretrained(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
|
|
" model = AutoModel.from_pretrained(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
|
|
" return model, tokenizer\n",
|
|
"\n",
|
|
"def predict_fn(data, model_and_tokenizer):\n",
|
|
" # destruct model and tokenizer\n",
|
|
" model, tokenizer = model_and_tokenizer\n",
|
|
"\n",
|
|
" # Tokenize sentences\n",
|
|
" sentences = data.pop(\"inputs\", data)\n",
|
|
" encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')\n",
|
|
"\n",
|
|
" # Compute token embeddings\n",
|
|
" with torch.no_grad():\n",
|
|
" model_output = model(**encoded_input)\n",
|
|
"\n",
|
|
" # Perform pooling\n",
|
|
" sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n",
|
|
"\n",
|
|
" # Normalize embeddings\n",
|
|
" sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n",
|
|
"\n",
|
|
" # return dictonary, which will be json serializable\n",
|
|
" return {\"embeddings\": sentence_embeddings[0].tolist()}\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "24b809d4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"code/\n",
|
|
"code/inference.py\n",
|
|
"----!"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sagemaker.s3 import S3Uploader\n",
|
|
"from sagemaker.huggingface.model import HuggingFaceModel\n",
|
|
"\n",
|
|
"# create model.tar.gz and upload to s3 \n",
|
|
"parent_dir=os.getcwd()\n",
|
|
"# change to model dir\n",
|
|
"os.chdir(\"model\")\n",
|
|
"# use pigz for faster and parallel compression\n",
|
|
"!tar zcvf model.tar.gz *\n",
|
|
"# change back to parent dir\n",
|
|
"os.chdir(parent_dir)\n",
|
|
"\n",
|
|
"\n",
|
|
"# upload model.tar.gz to s3\n",
|
|
"s3_model_uri = S3Uploader.upload(local_path=\"model/model.tar.gz\", desired_s3_uri=f\"s3://{sess.default_bucket()}/embeddings\")\n",
|
|
"\n",
|
|
"# create Hugging Face Model Class\n",
|
|
"huggingface_model = HuggingFaceModel(\n",
|
|
" model_data=s3_model_uri, # path to your model and script\n",
|
|
" role=role, # iam role with permissions to create an Endpoint\n",
|
|
" transformers_version=\"4.26\", # transformers version used\n",
|
|
" pytorch_version=\"1.13\", # pytorch version used\n",
|
|
" py_version='py39', # python version used\n",
|
|
")\n",
|
|
"\n",
|
|
"# deploy the endpoint endpoint\n",
|
|
"predictor = huggingface_model.deploy(1,\"ml.m5.2xlarge\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "324213fd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'huggingface-pytorch-inference-2023-03-21-16-14-03-834'"
|
|
]
|
|
},
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"predictor.endpoint_name"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "3dff3efa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"{'embeddings': [-0.03833858296275139,\n",
|
|
" 0.12346473336219788,\n",
|
|
" -0.028642961755394936,\n",
|
|
" 0.05365271493792534,\n",
|
|
" 0.008845399133861065,\n",
|
|
" -0.039839327335357666,\n",
|
|
" -0.07300589978694916,\n",
|
|
" 0.04777129739522934,\n",
|
|
" -0.03046245686709881,\n",
|
|
" 0.054979756474494934,\n",
|
|
" 0.08505291491746902,\n",
|
|
" 0.03665667772293091,\n",
|
|
" -0.0053200023248791695,\n",
|
|
" -0.002233208389952779,\n",
|
|
" -0.06071101501584053,\n",
|
|
" -0.027237888425588608,\n",
|
|
" -0.011351668275892735,\n",
|
|
" -0.04243773967027664,\n",
|
|
" 0.009129947051405907,\n",
|
|
" 0.10081552714109421,\n",
|
|
" 0.075787253677845,\n",
|
|
" 0.06911724805831909,\n",
|
|
" 0.009857476688921452,\n",
|
|
" -0.0018377384403720498,\n",
|
|
" 0.02624901942908764,\n",
|
|
" 0.03290242329239845,\n",
|
|
" -0.07177436351776123,\n",
|
|
" 0.028384245932102203,\n",
|
|
" 0.06170952320098877,\n",
|
|
" -0.05252952501177788,\n",
|
|
" 0.033661700785160065,\n",
|
|
" 0.07446815073490143,\n",
|
|
" 0.07536035776138306,\n",
|
|
" 0.03538404032588005,\n",
|
|
" 0.0671340748667717,\n",
|
|
" 0.01079804077744484,\n",
|
|
" 0.08167019486427307,\n",
|
|
" 0.01656288281083107,\n",
|
|
" 0.03283063322305679,\n",
|
|
" 0.03632563352584839,\n",
|
|
" 0.002172857290133834,\n",
|
|
" -0.09895739704370499,\n",
|
|
" 0.005046747159212828,\n",
|
|
" 0.050896503031253815,\n",
|
|
" 0.009287566877901554,\n",
|
|
" 0.024507733061909676,\n",
|
|
" -0.0644078254699707,\n",
|
|
" 0.0019362837774679065,\n",
|
|
" -0.079103484749794,\n",
|
|
" 0.020850397646427155,\n",
|
|
" -0.01922827586531639,\n",
|
|
" -0.02805466018617153,\n",
|
|
" -0.07059794664382935,\n",
|
|
" -0.007083615753799677,\n",
|
|
" 0.01040570717304945,\n",
|
|
" 0.038834139704704285,\n",
|
|
" 0.01765601523220539,\n",
|
|
" -0.019606105983257294,\n",
|
|
" -0.020058417692780495,\n",
|
|
" 0.018083792179822922,\n",
|
|
" -0.00017212114471476525,\n",
|
|
" 0.013043343089520931,\n",
|
|
" -0.09337250143289566,\n",
|
|
" 0.08453577756881714,\n",
|
|
" 0.11705499142408371,\n",
|
|
" 0.057413410395383835,\n",
|
|
" -0.022439058870077133,\n",
|
|
" -0.03677624836564064,\n",
|
|
" -0.03434618189930916,\n",
|
|
" -0.06383830308914185,\n",
|
|
" -0.06846101582050323,\n",
|
|
" -0.005553076509386301,\n",
|
|
" 0.044378429651260376,\n",
|
|
" 0.016669290140271187,\n",
|
|
" 0.030911751091480255,\n",
|
|
" -0.01975969970226288,\n",
|
|
" -0.024855101481080055,\n",
|
|
" -0.05904391035437584,\n",
|
|
" 0.0945875272154808,\n",
|
|
" -0.06530515849590302,\n",
|
|
" -0.05597255751490593,\n",
|
|
" -0.03284724801778793,\n",
|
|
" 0.00811521615833044,\n",
|
|
" -0.002234684070572257,\n",
|
|
" 0.002023296197876334,\n",
|
|
" 0.07942128926515579,\n",
|
|
" 0.08518771082162857,\n",
|
|
" 0.007815245538949966,\n",
|
|
" -0.01374559011310339,\n",
|
|
" 0.031104223802685738,\n",
|
|
" 0.010080904699862003,\n",
|
|
" -0.03275560960173607,\n",
|
|
" 0.007714808918535709,\n",
|
|
" -0.006191879045218229,\n",
|
|
" -0.05613413453102112,\n",
|
|
" 0.004364899825304747,\n",
|
|
" -0.01403757743537426,\n",
|
|
" -0.039304714649915695,\n",
|
|
" 0.07822350412607193,\n",
|
|
" 0.07393720000982285,\n",
|
|
" 0.05619140341877937,\n",
|
|
" 0.003301335731521249,\n",
|
|
" 0.04155803844332695,\n",
|
|
" -0.010387539863586426,\n",
|
|
" -0.13272696733474731,\n",
|
|
" -0.10473112016916275,\n",
|
|
" 0.018451808020472527,\n",
|
|
" -0.07520624995231628,\n",
|
|
" 0.04954085499048233,\n",
|
|
" -0.028530888259410858,\n",
|
|
" -0.01358408946543932,\n",
|
|
" -0.037112679332494736,\n",
|
|
" -0.06756578385829926,\n",
|
|
" -0.019552525132894516,\n",
|
|
" -0.010211824439466,\n",
|
|
" -0.051934875547885895,\n",
|
|
" -0.05941231921315193,\n",
|
|
" 0.016754044219851494,\n",
|
|
" 0.04098018631339073,\n",
|
|
" 0.001522376318462193,\n",
|
|
" 0.08095283806324005,\n",
|
|
" 0.002651068614795804,\n",
|
|
" -0.03870720788836479,\n",
|
|
" -0.04703034833073616,\n",
|
|
" -0.05854427441954613,\n",
|
|
" -0.029478492215275764,\n",
|
|
" 0.03882651776075363,\n",
|
|
" -8.102625254868425e-33,\n",
|
|
" -0.012914206832647324,\n",
|
|
" -0.014458492398262024,\n",
|
|
" -0.022368784993886948,\n",
|
|
" 0.1056450605392456,\n",
|
|
" 0.0037274654023349285,\n",
|
|
" 0.005939559079706669,\n",
|
|
" -0.023657256737351418,\n",
|
|
" 0.041163913905620575,\n",
|
|
" -0.07411694526672363,\n",
|
|
" 0.007076926529407501,\n",
|
|
" 0.0018349214224144816,\n",
|
|
" -0.03314222767949104,\n",
|
|
" 0.006818821653723717,\n",
|
|
" 0.04693515598773956,\n",
|
|
" -0.03836120665073395,\n",
|
|
" 0.05861291661858559,\n",
|
|
" -0.0840379074215889,\n",
|
|
" 0.11954139918088913,\n",
|
|
" -0.025204092264175415,\n",
|
|
" 0.02761165052652359,\n",
|
|
" 0.0244757030159235,\n",
|
|
" 0.014137371443212032,\n",
|
|
" 0.0128665491938591,\n",
|
|
" -0.05779572203755379,\n",
|
|
" -0.031691741198301315,\n",
|
|
" -0.0029006320983171463,\n",
|
|
" -0.027254171669483185,\n",
|
|
" -0.027451230213046074,\n",
|
|
" -0.03404244780540466,\n",
|
|
" 0.020136823877692223,\n",
|
|
" 0.022654512897133827,\n",
|
|
" 0.030933434143662453,\n",
|
|
" -0.045505885034799576,\n",
|
|
" -0.0025163793470710516,\n",
|
|
" 0.01510235108435154,\n",
|
|
" 0.09668111801147461,\n",
|
|
" 0.001809411682188511,\n",
|
|
" -0.05403870716691017,\n",
|
|
" 0.0025403527542948723,\n",
|
|
" 0.006051000207662582,\n",
|
|
" -0.056302234530448914,\n",
|
|
" -0.028254246339201927,\n",
|
|
" 0.06966646015644073,\n",
|
|
" 0.04410792514681816,\n",
|
|
" 0.039832279086112976,\n",
|
|
" -0.0419430211186409,\n",
|
|
" -0.0038099137600511312,\n",
|
|
" -0.04156690835952759,\n",
|
|
" 0.09482309967279434,\n",
|
|
" 0.019028929993510246,\n",
|
|
" -0.04011702537536621,\n",
|
|
" 0.0324222669005394,\n",
|
|
" 0.012565849348902702,\n",
|
|
" -0.056325893849134445,\n",
|
|
" 0.04461190849542618,\n",
|
|
" 0.04928917437791824,\n",
|
|
" 0.017442630603909492,\n",
|
|
" 0.05323149263858795,\n",
|
|
" -0.020876457914710045,\n",
|
|
" 0.061462536454200745,\n",
|
|
" -0.014837260358035564,\n",
|
|
" 0.07423629611730576,\n",
|
|
" -0.0576944537460804,\n",
|
|
" 0.049852192401885986,\n",
|
|
" -0.05890402942895889,\n",
|
|
" -0.0006539729074575007,\n",
|
|
" -0.10970547795295715,\n",
|
|
" -0.06829895824193954,\n",
|
|
" 0.13056595623493195,\n",
|
|
" -0.011906635947525501,\n",
|
|
" -0.0159984789788723,\n",
|
|
" -0.0211041159927845,\n",
|
|
" -0.007144191302359104,\n",
|
|
" -0.0164438858628273,\n",
|
|
" -0.016906214877963066,\n",
|
|
" -0.04813709110021591,\n",
|
|
" 0.015731733292341232,\n",
|
|
" 0.030654815956950188,\n",
|
|
" -0.004599860403686762,\n",
|
|
" -0.03823969140648842,\n",
|
|
" -0.04718682914972305,\n",
|
|
" -0.08068915456533432,\n",
|
|
" -0.011494779027998447,\n",
|
|
" -0.05190776288509369,\n",
|
|
" -0.04332379251718521,\n",
|
|
" -0.019109943881630898,\n",
|
|
" 0.036341868340969086,\n",
|
|
" -0.06575313955545425,\n",
|
|
" -0.014969361014664173,\n",
|
|
" -0.0911363959312439,\n",
|
|
" 0.035127948969602585,\n",
|
|
" 0.019904181361198425,\n",
|
|
" -0.055992890149354935,\n",
|
|
" -0.04273851588368416,\n",
|
|
" 0.11667020618915558,\n",
|
|
" 4.7537233992963164e-33,\n",
|
|
" -0.04277687147259712,\n",
|
|
" 0.010693217627704144,\n",
|
|
" -0.08699914813041687,\n",
|
|
" 0.11428382992744446,\n",
|
|
" 0.026194244623184204,\n",
|
|
" 0.008768039755523205,\n",
|
|
" 0.08940346539020538,\n",
|
|
" -0.0019060149788856506,\n",
|
|
" -0.0455072745680809,\n",
|
|
" 0.08432017266750336,\n",
|
|
" 0.011060485616326332,\n",
|
|
" 0.000260289350990206,\n",
|
|
" -0.00023178635456133634,\n",
|
|
" -0.0015942883910611272,\n",
|
|
" 0.0015580946346744895,\n",
|
|
" -0.025324126705527306,\n",
|
|
" -0.03786805272102356,\n",
|
|
" -0.0546313114464283,\n",
|
|
" 0.004270816687494516,\n",
|
|
" 0.016222011297941208,\n",
|
|
" -0.04763113334774971,\n",
|
|
" 0.11077607423067093,\n",
|
|
" 0.045782990753650665,\n",
|
|
" 0.07989457994699478,\n",
|
|
" -0.006792569998651743,\n",
|
|
" -0.010313649661839008,\n",
|
|
" 0.006975427269935608,\n",
|
|
" -0.09530742466449738,\n",
|
|
" -0.014356936328113079,\n",
|
|
" -0.013479162007570267,\n",
|
|
" -0.009381195530295372,\n",
|
|
" -0.0026153195649385452,\n",
|
|
" -0.12162390351295471,\n",
|
|
" 0.07765249162912369,\n",
|
|
" 0.009094372391700745,\n",
|
|
" -0.10183481127023697,\n",
|
|
" 0.13146239519119263,\n",
|
|
" -0.04587067291140556,\n",
|
|
" -0.009605005383491516,\n",
|
|
" 0.024302706122398376,\n",
|
|
" 0.045921340584754944,\n",
|
|
" 0.08771276473999023,\n",
|
|
" 0.055159058421850204,\n",
|
|
" 0.047116719186306,\n",
|
|
" -0.022800585255026817,\n",
|
|
" 0.05540422350168228,\n",
|
|
" 0.03942396119236946,\n",
|
|
" -0.06854791939258575,\n",
|
|
" 0.07696892321109772,\n",
|
|
" 0.0264807790517807,\n",
|
|
" 0.013421732001006603,\n",
|
|
" -0.03159027546644211,\n",
|
|
" 0.02122318185865879,\n",
|
|
" -0.02458374947309494,\n",
|
|
" -0.09490033239126205,\n",
|
|
" 0.05001789703965187,\n",
|
|
" -0.07885674387216568,\n",
|
|
" -0.0469261035323143,\n",
|
|
" -0.009405327029526234,\n",
|
|
" 0.06844945251941681,\n",
|
|
" -0.019532756879925728,\n",
|
|
" 0.08325397968292236,\n",
|
|
" -0.0020212731324136257,\n",
|
|
" 0.07861411571502686,\n",
|
|
" 0.009707036428153515,\n",
|
|
" -0.08329329639673233,\n",
|
|
" -0.08883728086948395,\n",
|
|
" 0.026159727945923805,\n",
|
|
" -0.0036121727898716927,\n",
|
|
" 0.0021212503779679537,\n",
|
|
" 0.06756487488746643,\n",
|
|
" -0.04351912811398506,\n",
|
|
" -0.031103378161787987,\n",
|
|
" -0.1055448055267334,\n",
|
|
" 0.08162888139486313,\n",
|
|
" -0.11693760007619858,\n",
|
|
" 0.0012153959833085537,\n",
|
|
" -0.042226288467645645,\n",
|
|
" -0.025040708482265472,\n",
|
|
" -0.05382077395915985,\n",
|
|
" 0.046688906848430634,\n",
|
|
" -0.004659516736865044,\n",
|
|
" -0.049144256860017776,\n",
|
|
" 0.05339549481868744,\n",
|
|
" -0.016824593767523766,\n",
|
|
" -0.018911045044660568,\n",
|
|
" 0.0021526776254177094,\n",
|
|
" 0.010545731522142887,\n",
|
|
" -0.02843359299004078,\n",
|
|
" 0.06319320946931839,\n",
|
|
" -0.041760899126529694,\n",
|
|
" 0.03648762032389641,\n",
|
|
" -0.028613677248358727,\n",
|
|
" 0.012441876344382763,\n",
|
|
" -0.030993392691016197,\n",
|
|
" -1.827941886745066e-08,\n",
|
|
" -0.03364746645092964,\n",
|
|
" -0.010457276366651058,\n",
|
|
" 0.006326176226139069,\n",
|
|
" -0.03394529968500137,\n",
|
|
" -0.03437081351876259,\n",
|
|
" 0.043725401163101196,\n",
|
|
" 0.07607871294021606,\n",
|
|
" -0.05076980963349342,\n",
|
|
" -0.06551552563905716,\n",
|
|
" -0.023710858076810837,\n",
|
|
" 0.05217289924621582,\n",
|
|
" 0.008229373954236507,\n",
|
|
" -0.05053586885333061,\n",
|
|
" -0.0046344115398824215,\n",
|
|
" 0.04596329480409622,\n",
|
|
" -0.048263613134622574,\n",
|
|
" -0.007646505255252123,\n",
|
|
" -0.0246701892465353,\n",
|
|
" -0.05899248272180557,\n",
|
|
" 0.02179579623043537,\n",
|
|
" -0.033197544515132904,\n",
|
|
" 0.026267115026712418,\n",
|
|
" 0.019565267488360405,\n",
|
|
" 0.022036483511328697,\n",
|
|
" -0.02707892283797264,\n",
|
|
" 0.07815380394458771,\n",
|
|
" 0.03259186074137688,\n",
|
|
" 0.10126295685768127,\n",
|
|
" 0.007166724652051926,\n",
|
|
" -0.031028350815176964,\n",
|
|
" 0.04080115631222725,\n",
|
|
" 0.10805943608283997,\n",
|
|
" -0.00941381324082613,\n",
|
|
" -0.01028114091604948,\n",
|
|
" 0.037279773503541946,\n",
|
|
" 0.11904413253068924,\n",
|
|
" 0.04982069879770279,\n",
|
|
" 0.05209505558013916,\n",
|
|
" 0.020246144384145737,\n",
|
|
" 0.05551902949810028,\n",
|
|
" -0.10270132124423981,\n",
|
|
" -0.009933318942785263,\n",
|
|
" -0.022510290145874023,\n",
|
|
" 0.03311152011156082,\n",
|
|
" 0.05227212607860565,\n",
|
|
" -0.029383286833763123,\n",
|
|
" -0.1383359581232071,\n",
|
|
" -0.014143865555524826,\n",
|
|
" -0.037659481167793274,\n",
|
|
" -0.08339183777570724,\n",
|
|
" -0.0034869578666985035,\n",
|
|
" -0.0415429063141346,\n",
|
|
" 0.04902830719947815,\n",
|
|
" 0.02155115082859993,\n",
|
|
" -0.040210600942373276,\n",
|
|
" 0.008557669818401337,\n",
|
|
" 0.046616844832897186,\n",
|
|
" -0.004114149138331413,\n",
|
|
" -0.03815949708223343,\n",
|
|
" -0.015223635360598564,\n",
|
|
" 0.12486445158720016,\n",
|
|
" 0.08800436556339264,\n",
|
|
" 0.08585748821496964,\n",
|
|
" -0.015338928438723087]}"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"predictor.predict({\"inputs\": \"This is a test document.\"})"
|
|
]
|
|
},
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "3764f159",
|
|
"metadata": {},
|
|
"source": [
|
|
"## _### END TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_ "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "1e9b926a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import Dict\n",
|
|
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
|
|
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
|
|
"import json\n",
|
|
"\n",
|
|
"\n",
|
|
"class ContentHandler(ContentHandlerBase):\n",
|
|
" content_type = \"application/json\"\n",
|
|
" accepts = \"application/json\"\n",
|
|
"\n",
|
|
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
|
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
|
|
" return input_str.encode('utf-8')\n",
|
|
" \n",
|
|
" def transform_output(self, output: bytes) -> str:\n",
|
|
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
|
" return response_json[\"embeddings\"]\n",
|
|
"\n",
|
|
"content_handler = ContentHandler()\n",
|
|
"\n",
|
|
"\n",
|
|
"embeddings = SagemakerEndpointEmbeddings(\n",
|
|
" # endpoint_name=\"endpoint-name\", \n",
|
|
" # credentials_profile_name=\"credentials-profile-name\", \n",
|
|
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\", \n",
|
|
" region_name=\"us-east-1\", \n",
|
|
" content_handler=content_handler\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "836e3ea5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[0.01623339205980301,\n",
|
|
" -0.007662336342036724,\n",
|
|
" 0.018606489524245262,\n",
|
|
" 0.031968992203474045,\n",
|
|
" -0.031003747135400772,\n",
|
|
" 0.008777972310781479,\n",
|
|
" 0.1594553291797638,\n",
|
|
" -0.009521624073386192,\n",
|
|
" 0.020200366154313087,\n",
|
|
" -0.04545809328556061,\n",
|
|
" 0.013985812664031982,\n",
|
|
" -0.017674963921308517,\n",
|
|
" -0.03616964817047119,\n",
|
|
" -0.02194339968264103,\n",
|
|
" 0.021387653425335884,\n",
|
|
" 0.06459270417690277,\n",
|
|
" -0.03659535571932793,\n",
|
|
" -0.01213359646499157,\n",
|
|
" -0.043666232377290726,\n",
|
|
" -0.03515005484223366,\n",
|
|
" -0.032629866153001785,\n",
|
|
" 0.07834123075008392,\n",
|
|
" -0.021041689440608025,\n",
|
|
" 0.03372766822576523,\n",
|
|
" -0.024157941341400146,\n",
|
|
" -0.010767146944999695,\n",
|
|
" -0.042864806950092316,\n",
|
|
" 0.013539575971662998,\n",
|
|
" 0.05039731785655022,\n",
|
|
" -0.091956727206707,\n",
|
|
" 0.035494621843099594,\n",
|
|
" 0.18029741942882538,\n",
|
|
" 0.01576363667845726,\n",
|
|
" -0.04949156939983368,\n",
|
|
" -0.003976485226303339,\n",
|
|
" 0.00032106428989209235,\n",
|
|
" 0.021849628537893295,\n",
|
|
" 0.035368386656045914,\n",
|
|
" 0.04185418039560318,\n",
|
|
" 0.04899369180202484,\n",
|
|
" -0.026651302352547646,\n",
|
|
" -0.05650882050395012,\n",
|
|
" -0.03276852145791054,\n",
|
|
" -0.020723465830087662,\n",
|
|
" -0.011230835691094398,\n",
|
|
" 0.02798161283135414,\n",
|
|
" -0.010538998991250992,\n",
|
|
" 0.030317796394228935,\n",
|
|
" 0.017697133123874664,\n",
|
|
" 0.003633821150287986,\n",
|
|
" -0.008708533830940723,\n",
|
|
" -0.04946836829185486,\n",
|
|
" -0.029903240501880646,\n",
|
|
" 0.022750651463866234,\n",
|
|
" 0.09276428818702698,\n",
|
|
" 0.05072581395506859,\n",
|
|
" 0.02917262725532055,\n",
|
|
" 0.00728880288079381,\n",
|
|
" -0.011496285907924175,\n",
|
|
" -0.05313197895884514,\n",
|
|
" -0.027890320867300034,\n",
|
|
" 0.030064044520258904,\n",
|
|
" -0.06029842048883438,\n",
|
|
" -0.043088313192129135,\n",
|
|
" 0.05004483461380005,\n",
|
|
" 0.0015685138059780002,\n",
|
|
" -0.01834270916879177,\n",
|
|
" 0.046504270285367966,\n",
|
|
" -0.043405696749687195,\n",
|
|
" 0.08440472185611725,\n",
|
|
" 0.022881966084241867,\n",
|
|
" 0.013790522702038288,\n",
|
|
" 0.03525456413626671,\n",
|
|
" 0.08282686769962311,\n",
|
|
" 0.031224273145198822,\n",
|
|
" -0.032255761325359344,\n",
|
|
" 0.033190108835697174,\n",
|
|
" -0.02879202552139759,\n",
|
|
" 0.09641945362091064,\n",
|
|
" 0.014541308395564556,\n",
|
|
" -0.0425688773393631,\n",
|
|
" 0.007836293429136276,\n",
|
|
" -0.07434573024511337,\n",
|
|
" -0.03844423592090607,\n",
|
|
" 0.007907251827418804,\n",
|
|
" 0.005604865029454231,\n",
|
|
" 0.014666788280010223,\n",
|
|
" -0.015787949785590172,\n",
|
|
" -0.011632969602942467,\n",
|
|
" 0.06502652168273926,\n",
|
|
" -0.09462911635637283,\n",
|
|
" -0.05418006703257561,\n",
|
|
" 0.07266692817211151,\n",
|
|
" -0.0059609077870845795,\n",
|
|
" 0.015150884166359901,\n",
|
|
" 0.033904850482940674,\n",
|
|
" 0.04719925299286842,\n",
|
|
" -0.006713803857564926,\n",
|
|
" -0.0628967359662056,\n",
|
|
" 0.2842618525028229,\n",
|
|
" -0.007497070357203484,\n",
|
|
" 0.11969559639692307,\n",
|
|
" 0.047007955610752106,\n",
|
|
" -0.023929651826620102,\n",
|
|
" 0.015414181165397167,\n",
|
|
" -0.029861856251955032,\n",
|
|
" -0.014299873262643814,\n",
|
|
" 0.018457207828760147,\n",
|
|
" 0.05915089696645737,\n",
|
|
" -0.03441261500120163,\n",
|
|
" -0.01635487750172615,\n",
|
|
" -0.021376853808760643,\n",
|
|
" -0.01367877796292305,\n",
|
|
" -0.04958583787083626,\n",
|
|
" 0.01463531143963337,\n",
|
|
" 0.01211540587246418,\n",
|
|
" -0.005459521431475878,\n",
|
|
" 0.005695596802979708,\n",
|
|
" -0.04747362807393074,\n",
|
|
" -0.05998056381940842,\n",
|
|
" -0.042840663343667984,\n",
|
|
" 0.04042218253016472,\n",
|
|
" -0.006624337285757065,\n",
|
|
" 0.025121480226516724,\n",
|
|
" -0.054944958537817,\n",
|
|
" -0.06516158580780029,\n",
|
|
" 0.007337308023124933,\n",
|
|
" -6.10322324013495e-33,\n",
|
|
" 0.002179093426093459,\n",
|
|
" -0.073213592171669,\n",
|
|
" -0.014703890308737755,\n",
|
|
" 0.00238825217820704,\n",
|
|
" 0.02046307921409607,\n",
|
|
" -0.06456342339515686,\n",
|
|
" 0.014286896213889122,\n",
|
|
" 0.02082856185734272,\n",
|
|
" -0.07692538946866989,\n",
|
|
" 0.09246989339590073,\n",
|
|
" -0.03469334542751312,\n",
|
|
" 0.022259987890720367,\n",
|
|
" -0.0369521826505661,\n",
|
|
" -0.0876070111989975,\n",
|
|
" 0.13785682618618011,\n",
|
|
" -0.000683621852658689,\n",
|
|
" 0.0018552240217104554,\n",
|
|
" 0.07194776087999344,\n",
|
|
" -0.0633404403924942,\n",
|
|
" -0.01646324060857296,\n",
|
|
" -0.0361541211605072,\n",
|
|
" -0.006936112884432077,\n",
|
|
" 0.003252814756706357,\n",
|
|
" 0.02627389132976532,\n",
|
|
" -0.0014277833979576826,\n",
|
|
" -0.09001296013593674,\n",
|
|
" 0.008833721280097961,\n",
|
|
" -0.07455790787935257,\n",
|
|
" 0.10064911842346191,\n",
|
|
" 0.03227779641747475,\n",
|
|
" -0.016069436445832253,\n",
|
|
" 0.024673042818903923,\n",
|
|
" 0.04188213497400284,\n",
|
|
" 0.03961843252182007,\n",
|
|
" -0.028469551354646683,\n",
|
|
" -0.05262545496225357,\n",
|
|
" -0.006966825108975172,\n",
|
|
" -0.0033834113273769617,\n",
|
|
" -0.038578882813453674,\n",
|
|
" -0.010265848599374294,\n",
|
|
" -0.033789463341236115,\n",
|
|
" 0.0030778711661696434,\n",
|
|
" -0.05088731646537781,\n",
|
|
" -0.019024258479475975,\n",
|
|
" 0.05421010032296181,\n",
|
|
" 0.015494044870138168,\n",
|
|
" -0.009311210364103317,\n",
|
|
" 0.0050599598325788975,\n",
|
|
" -0.04918931797146797,\n",
|
|
" 0.03970836102962494,\n",
|
|
" 0.06579958647489548,\n",
|
|
" 0.014110234566032887,\n",
|
|
" -0.04829266294836998,\n",
|
|
" 0.05065532773733139,\n",
|
|
" 0.021345015615224838,\n",
|
|
" -0.02805492654442787,\n",
|
|
" -0.013115333393216133,\n",
|
|
" -0.03833610191941261,\n",
|
|
" 0.0081633934751153,\n",
|
|
" 0.0020320340991020203,\n",
|
|
" 0.025601046159863472,\n",
|
|
" 0.046745311468839645,\n",
|
|
" -0.07602663338184357,\n",
|
|
" 0.08589514344930649,\n",
|
|
" -0.09630884975194931,\n",
|
|
" 0.01156257651746273,\n",
|
|
" 0.047838304191827774,\n",
|
|
" -0.03707060590386391,\n",
|
|
" 0.05717772990465164,\n",
|
|
" -0.028168894350528717,\n",
|
|
" -0.06691361963748932,\n",
|
|
" 0.003909755032509565,\n",
|
|
" -0.01265989150851965,\n",
|
|
" -0.024667585268616676,\n",
|
|
" -0.04399942606687546,\n",
|
|
" 0.013469734229147434,\n",
|
|
" 0.013298758305609226,\n",
|
|
" 0.0409042127430439,\n",
|
|
" -0.012081797234714031,\n",
|
|
" -0.009779289364814758,\n",
|
|
" 0.021113228052854538,\n",
|
|
" -0.06191551312804222,\n",
|
|
" -0.010964356362819672,\n",
|
|
" 0.027119100093841553,\n",
|
|
" -0.03144009783864021,\n",
|
|
" 0.037719033658504486,\n",
|
|
" 0.02421882562339306,\n",
|
|
" -0.13700149953365326,\n",
|
|
" 0.0038421833887696266,\n",
|
|
" -0.06574120372533798,\n",
|
|
" -0.12629178166389465,\n",
|
|
" 0.018397213891148567,\n",
|
|
" 0.0019562605302780867,\n",
|
|
" -0.06581622362136841,\n",
|
|
" 0.0056412797421216965,\n",
|
|
" 6.17423926197194e-33,\n",
|
|
" 0.11609319597482681,\n",
|
|
" 0.023075049743056297,\n",
|
|
" -0.02540658414363861,\n",
|
|
" 0.021112393587827682,\n",
|
|
" -0.010050611570477486,\n",
|
|
" 0.0045014130882918835,\n",
|
|
" 0.02216450683772564,\n",
|
|
" 0.03083667904138565,\n",
|
|
" -0.065506212413311,\n",
|
|
" -0.028498610481619835,\n",
|
|
" -0.08708083629608154,\n",
|
|
" -0.027195820584893227,\n",
|
|
" 0.04075731709599495,\n",
|
|
" -0.00738579360768199,\n",
|
|
" 0.031747449189424515,\n",
|
|
" 0.020246611908078194,\n",
|
|
" 0.03285415843129158,\n",
|
|
" -0.037579674273729324,\n",
|
|
" -0.025780295953154564,\n",
|
|
" 0.044498566538095474,\n",
|
|
" -0.01600523293018341,\n",
|
|
" -0.1110200434923172,\n",
|
|
" 0.10275887697935104,\n",
|
|
" -0.044455550611019135,\n",
|
|
" -0.043082430958747864,\n",
|
|
" 0.04361744225025177,\n",
|
|
" 0.09388253092765808,\n",
|
|
" 0.03668423742055893,\n",
|
|
" -0.08740879595279694,\n",
|
|
" -0.015174541622400284,\n",
|
|
" 0.035617802292108536,\n",
|
|
" -0.056008175015449524,\n",
|
|
" -0.07729960232973099,\n",
|
|
" -0.055068857967853546,\n",
|
|
" 0.011802412569522858,\n",
|
|
" 0.0005090870545245707,\n",
|
|
" -0.04531490430235863,\n",
|
|
" 0.009635107591748238,\n",
|
|
" 0.0066973078064620495,\n",
|
|
" -0.08850639313459396,\n",
|
|
" 0.07926266640424728,\n",
|
|
" 0.03328349068760872,\n",
|
|
" 0.02206319011747837,\n",
|
|
" 0.08003410696983337,\n",
|
|
" -0.004926585592329502,\n",
|
|
" -0.012855191715061665,\n",
|
|
" -0.030001787468791008,\n",
|
|
" 0.0038301211316138506,\n",
|
|
" 0.09513995796442032,\n",
|
|
" -0.023254361003637314,\n",
|
|
" -0.01384524442255497,\n",
|
|
" -0.0006733545451425016,\n",
|
|
" 0.004949721973389387,\n",
|
|
" -0.03836912661790848,\n",
|
|
" -0.0484086349606514,\n",
|
|
" -0.04300595819950104,\n",
|
|
" -0.03302333503961563,\n",
|
|
" -0.011142191477119923,\n",
|
|
" -0.021775009110569954,\n",
|
|
" 0.009556151926517487,\n",
|
|
" -0.014081847853958607,\n",
|
|
" 0.01725372113287449,\n",
|
|
" -0.002208009362220764,\n",
|
|
" 0.043982796370983124,\n",
|
|
" -0.12186389416456223,\n",
|
|
" -0.03109029121696949,\n",
|
|
" -0.0648212656378746,\n",
|
|
" -0.03446059674024582,\n",
|
|
" -0.0009474779944866896,\n",
|
|
" 0.019224559888243675,\n",
|
|
" 0.030093936249613762,\n",
|
|
" 0.011459640227258205,\n",
|
|
" -0.031019125133752823,\n",
|
|
" 0.11076018959283829,\n",
|
|
" -0.08466918021440506,\n",
|
|
" -0.028721166774630547,\n",
|
|
" -0.006525673437863588,\n",
|
|
" 0.05877530202269554,\n",
|
|
" 0.021319882944226265,\n",
|
|
" 0.08542844653129578,\n",
|
|
" 0.05103899911046028,\n",
|
|
" -0.02113465592265129,\n",
|
|
" 0.01493119541555643,\n",
|
|
" 0.010513859800994396,\n",
|
|
" -0.023147936910390854,\n",
|
|
" -0.044208601117134094,\n",
|
|
" -0.0010544550605118275,\n",
|
|
" 0.0656798928976059,\n",
|
|
" -0.013098708353936672,\n",
|
|
" 0.0029119630344212055,\n",
|
|
" 0.03165023773908615,\n",
|
|
" 0.06931225955486298,\n",
|
|
" -0.02299979329109192,\n",
|
|
" 0.022364258766174316,\n",
|
|
" -0.04974697157740593,\n",
|
|
" -1.3714036128931184e-08,\n",
|
|
" -0.0205942764878273,\n",
|
|
" 0.047028280794620514,\n",
|
|
" -0.032210975885391235,\n",
|
|
" 0.049078319221735,\n",
|
|
" 0.0394253209233284,\n",
|
|
" 0.10298289358615875,\n",
|
|
" 0.013628372922539711,\n",
|
|
" -0.07071256637573242,\n",
|
|
" -0.001111415564082563,\n",
|
|
" 0.045793063938617706,\n",
|
|
" 0.010663686320185661,\n",
|
|
" 0.022661199793219566,\n",
|
|
" -0.00039414051570929587,\n",
|
|
" 0.04868670925498009,\n",
|
|
" 0.08181674033403397,\n",
|
|
" -0.06234998628497124,\n",
|
|
" -0.017647461965680122,\n",
|
|
" -0.05699630081653595,\n",
|
|
" -0.035604529082775116,\n",
|
|
" -0.002848744625225663,\n",
|
|
" -0.07433759421110153,\n",
|
|
" 0.05970819666981697,\n",
|
|
" -0.03040698915719986,\n",
|
|
" -0.03587964177131653,\n",
|
|
" -0.05538871884346008,\n",
|
|
" -0.007939192466437817,\n",
|
|
" -0.015285325236618519,\n",
|
|
" 0.08461211621761322,\n",
|
|
" 0.01166541874408722,\n",
|
|
" 0.03213988244533539,\n",
|
|
" 0.05643611401319504,\n",
|
|
" 0.2006419152021408,\n",
|
|
" -0.07411110401153564,\n",
|
|
" -0.018009720370173454,\n",
|
|
" 0.016179822385311127,\n",
|
|
" -0.0028461480978876352,\n",
|
|
" 0.0402149073779583,\n",
|
|
" 0.0006247136043384671,\n",
|
|
" 0.0006973804556764662,\n",
|
|
" 0.09922358393669128,\n",
|
|
" -0.029822450131177902,\n",
|
|
" -0.005783025175333023,\n",
|
|
" -0.0028224103152751923,\n",
|
|
" -0.11175407469272614,\n",
|
|
" 0.012009709142148495,\n",
|
|
" -0.009956827387213707,\n",
|
|
" 0.011468647047877312,\n",
|
|
" -0.054449401795864105,\n",
|
|
" -0.016370657831430435,\n",
|
|
" 0.022106735035777092,\n",
|
|
" 0.03950563445687294,\n",
|
|
" 0.005319684278219938,\n",
|
|
" 0.042190469801425934,\n",
|
|
" 0.08844445645809174,\n",
|
|
" 0.0810166597366333,\n",
|
|
" 0.06980433315038681,\n",
|
|
" -0.04784897342324257,\n",
|
|
" 0.01753094792366028,\n",
|
|
" -0.10126522183418274,\n",
|
|
" -0.016526369377970695,\n",
|
|
" 0.11310216039419174,\n",
|
|
" 0.0874418243765831,\n",
|
|
" 0.09520682692527771,\n",
|
|
" 0.10083616524934769]"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"query_result = embeddings.embed_query(\"foo\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "76f1b752",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"doc_results = embeddings.embed_documents([\"foo\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "221f2f0e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[[0.01623339205980301,\n",
|
|
" -0.007662336342036724,\n",
|
|
" 0.018606489524245262,\n",
|
|
" 0.031968992203474045,\n",
|
|
" -0.031003747135400772,\n",
|
|
" 0.008777972310781479,\n",
|
|
" 0.1594553291797638,\n",
|
|
" -0.009521624073386192,\n",
|
|
" 0.020200366154313087,\n",
|
|
" -0.04545809328556061,\n",
|
|
" 0.013985812664031982,\n",
|
|
" -0.017674963921308517,\n",
|
|
" -0.03616964817047119,\n",
|
|
" -0.02194339968264103,\n",
|
|
" 0.021387653425335884,\n",
|
|
" 0.06459270417690277,\n",
|
|
" -0.03659535571932793,\n",
|
|
" -0.01213359646499157,\n",
|
|
" -0.043666232377290726,\n",
|
|
" -0.03515005484223366,\n",
|
|
" -0.032629866153001785,\n",
|
|
" 0.07834123075008392,\n",
|
|
" -0.021041689440608025,\n",
|
|
" 0.03372766822576523,\n",
|
|
" -0.024157941341400146,\n",
|
|
" -0.010767146944999695,\n",
|
|
" -0.042864806950092316,\n",
|
|
" 0.013539575971662998,\n",
|
|
" 0.05039731785655022,\n",
|
|
" -0.091956727206707,\n",
|
|
" 0.035494621843099594,\n",
|
|
" 0.18029741942882538,\n",
|
|
" 0.01576363667845726,\n",
|
|
" -0.04949156939983368,\n",
|
|
" -0.003976485226303339,\n",
|
|
" 0.00032106428989209235,\n",
|
|
" 0.021849628537893295,\n",
|
|
" 0.035368386656045914,\n",
|
|
" 0.04185418039560318,\n",
|
|
" 0.04899369180202484,\n",
|
|
" -0.026651302352547646,\n",
|
|
" -0.05650882050395012,\n",
|
|
" -0.03276852145791054,\n",
|
|
" -0.020723465830087662,\n",
|
|
" -0.011230835691094398,\n",
|
|
" 0.02798161283135414,\n",
|
|
" -0.010538998991250992,\n",
|
|
" 0.030317796394228935,\n",
|
|
" 0.017697133123874664,\n",
|
|
" 0.003633821150287986,\n",
|
|
" -0.008708533830940723,\n",
|
|
" -0.04946836829185486,\n",
|
|
" -0.029903240501880646,\n",
|
|
" 0.022750651463866234,\n",
|
|
" 0.09276428818702698,\n",
|
|
" 0.05072581395506859,\n",
|
|
" 0.02917262725532055,\n",
|
|
" 0.00728880288079381,\n",
|
|
" -0.011496285907924175,\n",
|
|
" -0.05313197895884514,\n",
|
|
" -0.027890320867300034,\n",
|
|
" 0.030064044520258904,\n",
|
|
" -0.06029842048883438,\n",
|
|
" -0.043088313192129135,\n",
|
|
" 0.05004483461380005,\n",
|
|
" 0.0015685138059780002,\n",
|
|
" -0.01834270916879177,\n",
|
|
" 0.046504270285367966,\n",
|
|
" -0.043405696749687195,\n",
|
|
" 0.08440472185611725,\n",
|
|
" 0.022881966084241867,\n",
|
|
" 0.013790522702038288,\n",
|
|
" 0.03525456413626671,\n",
|
|
" 0.08282686769962311,\n",
|
|
" 0.031224273145198822,\n",
|
|
" -0.032255761325359344,\n",
|
|
" 0.033190108835697174,\n",
|
|
" -0.02879202552139759,\n",
|
|
" 0.09641945362091064,\n",
|
|
" 0.014541308395564556,\n",
|
|
" -0.0425688773393631,\n",
|
|
" 0.007836293429136276,\n",
|
|
" -0.07434573024511337,\n",
|
|
" -0.03844423592090607,\n",
|
|
" 0.007907251827418804,\n",
|
|
" 0.005604865029454231,\n",
|
|
" 0.014666788280010223,\n",
|
|
" -0.015787949785590172,\n",
|
|
" -0.011632969602942467,\n",
|
|
" 0.06502652168273926,\n",
|
|
" -0.09462911635637283,\n",
|
|
" -0.05418006703257561,\n",
|
|
" 0.07266692817211151,\n",
|
|
" -0.0059609077870845795,\n",
|
|
" 0.015150884166359901,\n",
|
|
" 0.033904850482940674,\n",
|
|
" 0.04719925299286842,\n",
|
|
" -0.006713803857564926,\n",
|
|
" -0.0628967359662056,\n",
|
|
" 0.2842618525028229,\n",
|
|
" -0.007497070357203484,\n",
|
|
" 0.11969559639692307,\n",
|
|
" 0.047007955610752106,\n",
|
|
" -0.023929651826620102,\n",
|
|
" 0.015414181165397167,\n",
|
|
" -0.029861856251955032,\n",
|
|
" -0.014299873262643814,\n",
|
|
" 0.018457207828760147,\n",
|
|
" 0.05915089696645737,\n",
|
|
" -0.03441261500120163,\n",
|
|
" -0.01635487750172615,\n",
|
|
" -0.021376853808760643,\n",
|
|
" -0.01367877796292305,\n",
|
|
" -0.04958583787083626,\n",
|
|
" 0.01463531143963337,\n",
|
|
" 0.01211540587246418,\n",
|
|
" -0.005459521431475878,\n",
|
|
" 0.005695596802979708,\n",
|
|
" -0.04747362807393074,\n",
|
|
" -0.05998056381940842,\n",
|
|
" -0.042840663343667984,\n",
|
|
" 0.04042218253016472,\n",
|
|
" -0.006624337285757065,\n",
|
|
" 0.025121480226516724,\n",
|
|
" -0.054944958537817,\n",
|
|
" -0.06516158580780029,\n",
|
|
" 0.007337308023124933,\n",
|
|
" -6.10322324013495e-33,\n",
|
|
" 0.002179093426093459,\n",
|
|
" -0.073213592171669,\n",
|
|
" -0.014703890308737755,\n",
|
|
" 0.00238825217820704,\n",
|
|
" 0.02046307921409607,\n",
|
|
" -0.06456342339515686,\n",
|
|
" 0.014286896213889122,\n",
|
|
" 0.02082856185734272,\n",
|
|
" -0.07692538946866989,\n",
|
|
" 0.09246989339590073,\n",
|
|
" -0.03469334542751312,\n",
|
|
" 0.022259987890720367,\n",
|
|
" -0.0369521826505661,\n",
|
|
" -0.0876070111989975,\n",
|
|
" 0.13785682618618011,\n",
|
|
" -0.000683621852658689,\n",
|
|
" 0.0018552240217104554,\n",
|
|
" 0.07194776087999344,\n",
|
|
" -0.0633404403924942,\n",
|
|
" -0.01646324060857296,\n",
|
|
" -0.0361541211605072,\n",
|
|
" -0.006936112884432077,\n",
|
|
" 0.003252814756706357,\n",
|
|
" 0.02627389132976532,\n",
|
|
" -0.0014277833979576826,\n",
|
|
" -0.09001296013593674,\n",
|
|
" 0.008833721280097961,\n",
|
|
" -0.07455790787935257,\n",
|
|
" 0.10064911842346191,\n",
|
|
" 0.03227779641747475,\n",
|
|
" -0.016069436445832253,\n",
|
|
" 0.024673042818903923,\n",
|
|
" 0.04188213497400284,\n",
|
|
" 0.03961843252182007,\n",
|
|
" -0.028469551354646683,\n",
|
|
" -0.05262545496225357,\n",
|
|
" -0.006966825108975172,\n",
|
|
" -0.0033834113273769617,\n",
|
|
" -0.038578882813453674,\n",
|
|
" -0.010265848599374294,\n",
|
|
" -0.033789463341236115,\n",
|
|
" 0.0030778711661696434,\n",
|
|
" -0.05088731646537781,\n",
|
|
" -0.019024258479475975,\n",
|
|
" 0.05421010032296181,\n",
|
|
" 0.015494044870138168,\n",
|
|
" -0.009311210364103317,\n",
|
|
" 0.0050599598325788975,\n",
|
|
" -0.04918931797146797,\n",
|
|
" 0.03970836102962494,\n",
|
|
" 0.06579958647489548,\n",
|
|
" 0.014110234566032887,\n",
|
|
" -0.04829266294836998,\n",
|
|
" 0.05065532773733139,\n",
|
|
" 0.021345015615224838,\n",
|
|
" -0.02805492654442787,\n",
|
|
" -0.013115333393216133,\n",
|
|
" -0.03833610191941261,\n",
|
|
" 0.0081633934751153,\n",
|
|
" 0.0020320340991020203,\n",
|
|
" 0.025601046159863472,\n",
|
|
" 0.046745311468839645,\n",
|
|
" -0.07602663338184357,\n",
|
|
" 0.08589514344930649,\n",
|
|
" -0.09630884975194931,\n",
|
|
" 0.01156257651746273,\n",
|
|
" 0.047838304191827774,\n",
|
|
" -0.03707060590386391,\n",
|
|
" 0.05717772990465164,\n",
|
|
" -0.028168894350528717,\n",
|
|
" -0.06691361963748932,\n",
|
|
" 0.003909755032509565,\n",
|
|
" -0.01265989150851965,\n",
|
|
" -0.024667585268616676,\n",
|
|
" -0.04399942606687546,\n",
|
|
" 0.013469734229147434,\n",
|
|
" 0.013298758305609226,\n",
|
|
" 0.0409042127430439,\n",
|
|
" -0.012081797234714031,\n",
|
|
" -0.009779289364814758,\n",
|
|
" 0.021113228052854538,\n",
|
|
" -0.06191551312804222,\n",
|
|
" -0.010964356362819672,\n",
|
|
" 0.027119100093841553,\n",
|
|
" -0.03144009783864021,\n",
|
|
" 0.037719033658504486,\n",
|
|
" 0.02421882562339306,\n",
|
|
" -0.13700149953365326,\n",
|
|
" 0.0038421833887696266,\n",
|
|
" -0.06574120372533798,\n",
|
|
" -0.12629178166389465,\n",
|
|
" 0.018397213891148567,\n",
|
|
" 0.0019562605302780867,\n",
|
|
" -0.06581622362136841,\n",
|
|
" 0.0056412797421216965,\n",
|
|
" 6.17423926197194e-33,\n",
|
|
" 0.11609319597482681,\n",
|
|
" 0.023075049743056297,\n",
|
|
" -0.02540658414363861,\n",
|
|
" 0.021112393587827682,\n",
|
|
" -0.010050611570477486,\n",
|
|
" 0.0045014130882918835,\n",
|
|
" 0.02216450683772564,\n",
|
|
" 0.03083667904138565,\n",
|
|
" -0.065506212413311,\n",
|
|
" -0.028498610481619835,\n",
|
|
" -0.08708083629608154,\n",
|
|
" -0.027195820584893227,\n",
|
|
" 0.04075731709599495,\n",
|
|
" -0.00738579360768199,\n",
|
|
" 0.031747449189424515,\n",
|
|
" 0.020246611908078194,\n",
|
|
" 0.03285415843129158,\n",
|
|
" -0.037579674273729324,\n",
|
|
" -0.025780295953154564,\n",
|
|
" 0.044498566538095474,\n",
|
|
" -0.01600523293018341,\n",
|
|
" -0.1110200434923172,\n",
|
|
" 0.10275887697935104,\n",
|
|
" -0.044455550611019135,\n",
|
|
" -0.043082430958747864,\n",
|
|
" 0.04361744225025177,\n",
|
|
" 0.09388253092765808,\n",
|
|
" 0.03668423742055893,\n",
|
|
" -0.08740879595279694,\n",
|
|
" -0.015174541622400284,\n",
|
|
" 0.035617802292108536,\n",
|
|
" -0.056008175015449524,\n",
|
|
" -0.07729960232973099,\n",
|
|
" -0.055068857967853546,\n",
|
|
" 0.011802412569522858,\n",
|
|
" 0.0005090870545245707,\n",
|
|
" -0.04531490430235863,\n",
|
|
" 0.009635107591748238,\n",
|
|
" 0.0066973078064620495,\n",
|
|
" -0.08850639313459396,\n",
|
|
" 0.07926266640424728,\n",
|
|
" 0.03328349068760872,\n",
|
|
" 0.02206319011747837,\n",
|
|
" 0.08003410696983337,\n",
|
|
" -0.004926585592329502,\n",
|
|
" -0.012855191715061665,\n",
|
|
" -0.030001787468791008,\n",
|
|
" 0.0038301211316138506,\n",
|
|
" 0.09513995796442032,\n",
|
|
" -0.023254361003637314,\n",
|
|
" -0.01384524442255497,\n",
|
|
" -0.0006733545451425016,\n",
|
|
" 0.004949721973389387,\n",
|
|
" -0.03836912661790848,\n",
|
|
" -0.0484086349606514,\n",
|
|
" -0.04300595819950104,\n",
|
|
" -0.03302333503961563,\n",
|
|
" -0.011142191477119923,\n",
|
|
" -0.021775009110569954,\n",
|
|
" 0.009556151926517487,\n",
|
|
" -0.014081847853958607,\n",
|
|
" 0.01725372113287449,\n",
|
|
" -0.002208009362220764,\n",
|
|
" 0.043982796370983124,\n",
|
|
" -0.12186389416456223,\n",
|
|
" -0.03109029121696949,\n",
|
|
" -0.0648212656378746,\n",
|
|
" -0.03446059674024582,\n",
|
|
" -0.0009474779944866896,\n",
|
|
" 0.019224559888243675,\n",
|
|
" 0.030093936249613762,\n",
|
|
" 0.011459640227258205,\n",
|
|
" -0.031019125133752823,\n",
|
|
" 0.11076018959283829,\n",
|
|
" -0.08466918021440506,\n",
|
|
" -0.028721166774630547,\n",
|
|
" -0.006525673437863588,\n",
|
|
" 0.05877530202269554,\n",
|
|
" 0.021319882944226265,\n",
|
|
" 0.08542844653129578,\n",
|
|
" 0.05103899911046028,\n",
|
|
" -0.02113465592265129,\n",
|
|
" 0.01493119541555643,\n",
|
|
" 0.010513859800994396,\n",
|
|
" -0.023147936910390854,\n",
|
|
" -0.044208601117134094,\n",
|
|
" -0.0010544550605118275,\n",
|
|
" 0.0656798928976059,\n",
|
|
" -0.013098708353936672,\n",
|
|
" 0.0029119630344212055,\n",
|
|
" 0.03165023773908615,\n",
|
|
" 0.06931225955486298,\n",
|
|
" -0.02299979329109192,\n",
|
|
" 0.022364258766174316,\n",
|
|
" -0.04974697157740593,\n",
|
|
" -1.3714036128931184e-08,\n",
|
|
" -0.0205942764878273,\n",
|
|
" 0.047028280794620514,\n",
|
|
" -0.032210975885391235,\n",
|
|
" 0.049078319221735,\n",
|
|
" 0.0394253209233284,\n",
|
|
" 0.10298289358615875,\n",
|
|
" 0.013628372922539711,\n",
|
|
" -0.07071256637573242,\n",
|
|
" -0.001111415564082563,\n",
|
|
" 0.045793063938617706,\n",
|
|
" 0.010663686320185661,\n",
|
|
" 0.022661199793219566,\n",
|
|
" -0.00039414051570929587,\n",
|
|
" 0.04868670925498009,\n",
|
|
" 0.08181674033403397,\n",
|
|
" -0.06234998628497124,\n",
|
|
" -0.017647461965680122,\n",
|
|
" -0.05699630081653595,\n",
|
|
" -0.035604529082775116,\n",
|
|
" -0.002848744625225663,\n",
|
|
" -0.07433759421110153,\n",
|
|
" 0.05970819666981697,\n",
|
|
" -0.03040698915719986,\n",
|
|
" -0.03587964177131653,\n",
|
|
" -0.05538871884346008,\n",
|
|
" -0.007939192466437817,\n",
|
|
" -0.015285325236618519,\n",
|
|
" 0.08461211621761322,\n",
|
|
" 0.01166541874408722,\n",
|
|
" 0.03213988244533539,\n",
|
|
" 0.05643611401319504,\n",
|
|
" 0.2006419152021408,\n",
|
|
" -0.07411110401153564,\n",
|
|
" -0.018009720370173454,\n",
|
|
" 0.016179822385311127,\n",
|
|
" -0.0028461480978876352,\n",
|
|
" 0.0402149073779583,\n",
|
|
" 0.0006247136043384671,\n",
|
|
" 0.0006973804556764662,\n",
|
|
" 0.09922358393669128,\n",
|
|
" -0.029822450131177902,\n",
|
|
" -0.005783025175333023,\n",
|
|
" -0.0028224103152751923,\n",
|
|
" -0.11175407469272614,\n",
|
|
" 0.012009709142148495,\n",
|
|
" -0.009956827387213707,\n",
|
|
" 0.011468647047877312,\n",
|
|
" -0.054449401795864105,\n",
|
|
" -0.016370657831430435,\n",
|
|
" 0.022106735035777092,\n",
|
|
" 0.03950563445687294,\n",
|
|
" 0.005319684278219938,\n",
|
|
" 0.042190469801425934,\n",
|
|
" 0.08844445645809174,\n",
|
|
" 0.0810166597366333,\n",
|
|
" 0.06980433315038681,\n",
|
|
" -0.04784897342324257,\n",
|
|
" 0.01753094792366028,\n",
|
|
" -0.10126522183418274,\n",
|
|
" -0.016526369377970695,\n",
|
|
" 0.11310216039419174,\n",
|
|
" 0.0874418243765831,\n",
|
|
" 0.09520682692527771,\n",
|
|
" 0.10083616524934769]]"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"doc_results"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "aaad49f8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "langchain",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|