langchain/docs/modules/indexes/examples/embeddings.ipynb
Philipp Schmid 064be93edf
[Embeddings] Add SageMaker Endpoint Embedding class (#1859)
# What does this PR do? 

This PR adds a SageMaker-powered `embeddings` class, similar to the one that
already exists for `llms`. This is helpful if you want to leverage Hugging Face
models on SageMaker for creating your indexes.

I added an example to the
[docs/modules/indexes/examples/embeddings.ipynb](https://github.com/hwchase17/langchain/compare/master...philschmid:add-sm-embeddings?expand=1#diff-e82629e2894974ec87856aedd769d4bdfe400314b03734f32bee5990bc7e8062)
document. The example currently includes some `_### TEMPORARY: Showing
how to deploy a SageMaker Endpoint from a Hugging Face model ###_ ` code
showing how you can deploy a sentence-transformers model to SageMaker and
then run the methods of the embeddings class.

@hwchase17 please let me know if/when I should remove the `_###
TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging
Face model ###_` section. In the description I linked to a detailed blog
post on how to deploy a Sentence Transformers model, so I think we don't
need to include those steps here.

I also reused the `ContentHandlerBase` from
`langchain.llms.sagemaker_endpoint` and changed the output type to `any`
since it depends on the implementation.
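
To illustrate why the output type depends on the implementation, here is a minimal sketch of a content handler for an embeddings endpoint. It does not import LangChain: `SketchEmbeddingsContentHandler` is a hypothetical stand-in for a `ContentHandlerBase` subclass, the `"inputs"`/`"embeddings"` JSON keys are assumed to match what the deployed model expects and returns, and the response values are made up.

```python
import io
import json
from typing import Any, Dict, List


class SketchEmbeddingsContentHandler:
    """Hypothetical stand-in for a ContentHandlerBase subclass (not LangChain's class)."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict[str, Any]) -> bytes:
        # Serialize the text plus any extra model arguments into a JSON request body.
        return json.dumps({"inputs": prompt, **model_kwargs}).encode("utf-8")

    def transform_output(self, output: Any) -> List[float]:
        # `output` is assumed to be a file-like object exposing .read().
        # For an LLM this method would return a string; here it returns a vector.
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embeddings"]


handler = SketchEmbeddingsContentHandler()
request_body = handler.transform_input("This is a test document.", {})

# Simulate an endpoint response with an in-memory stream (made-up values).
fake_response = io.BytesIO(json.dumps({"embeddings": [0.1, -0.2, 0.3]}).encode("utf-8"))
vector = handler.transform_output(fake_response)
```

The same two-method shape serves both cases; only the return value of `transform_output` differs, which is why a fixed `str` return type would not fit an embeddings class.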
2023-03-21 21:51:48 -07:00


{
"cells": [
{
"cell_type": "markdown",
"id": "249b4058",
"metadata": {},
"source": [
"# Embeddings\n",
"\n",
"This notebook goes over how to use the Embedding class in LangChain.\n",
"\n",
"The Embedding class is a class designed for interfacing with embeddings. There are lots of Embedding providers (OpenAI, Cohere, Hugging Face, etc.); this class is designed to provide a standard interface for all of them.\n",
"\n",
"Embeddings create a vector representation of a piece of text. This is useful because it means we can think about text in the vector space, and do things like semantic search where we look for pieces of text that are most similar in the vector space.\n",
"\n",
"The base Embedding class in LangChain exposes two methods: `embed_documents` and `embed_query`. The main difference between them is the interface: `embed_documents` works over multiple documents, while `embed_query` works over a single query text. Another reason for keeping them as two separate methods is that some embedding providers use different embedding methods for documents (to be searched over) and queries (the search query itself)."
]
},
{
"cell_type": "markdown",
"id": "278b6c63",
"metadata": {},
"source": [
"## OpenAI\n",
"\n",
"Let's load the OpenAI Embedding class."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0be1af71",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import OpenAIEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2c66e5da",
"metadata": {},
"outputs": [],
"source": [
"embeddings = OpenAIEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "01370375",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bfb6142c",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0356c3b7",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "bb61bbeb",
"metadata": {},
"source": [
"Let's load the OpenAI Embedding class with first-generation models (e.g. text-search-ada-doc-001/text-search-ada-query-001). Note: these models are not recommended; see [here](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0b072cc",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings.openai import OpenAIEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a56b70f5",
"metadata": {},
"outputs": [],
"source": [
"embeddings = OpenAIEmbeddings(model_name=\"ada\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14aefb64",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c39ed33",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3221db6",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c3852491",
"metadata": {},
"source": [
"## AzureOpenAI\n",
"\n",
"Let's load the OpenAI Embedding class with the environment variables set so that it uses Azure endpoints."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b40f827",
"metadata": {},
"outputs": [],
"source": [
"# set the environment variables that tell the openai package to use Azure\n",
"import os\n",
"\n",
"os.environ[\"OPENAI_API_TYPE\"] = \"azure\"\n",
"os.environ[\"OPENAI_API_BASE\"] = \"https://<your-endpoint>.openai.azure.com/\"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"your AzureOpenAI key\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb36d16c",
"metadata": {},
"outputs": [],
"source": [
"embeddings = OpenAIEmbeddings(model=\"your-embeddings-deployment-name\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "228abcbb",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60dd7fad",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83bc1a72",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"cell_type": "markdown",
"id": "42f76e43",
"metadata": {},
"source": [
"## Cohere\n",
"\n",
"Let's load the Cohere Embedding class."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6b82f59f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import CohereEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "26895c60",
"metadata": {},
"outputs": [],
"source": [
"cohere_api_key = \"your Cohere API key\"  # placeholder: replace with your own key\n",
"embeddings = CohereEmbeddings(cohere_api_key=cohere_api_key)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "eea52814",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fbe167bf",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "38ad3b20",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"cell_type": "markdown",
"id": "ed47bb62",
"metadata": {},
"source": [
"## Hugging Face Hub\n",
"Let's load the Hugging Face Embedding class."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "861521a9",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import HuggingFaceEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ff9be586",
"metadata": {},
"outputs": [],
"source": [
"embeddings = HuggingFaceEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d0a98ae9",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "5d6c682b",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "bb5e74c0",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text])"
]
},
{
"cell_type": "markdown",
"id": "fff4734f",
"metadata": {},
"source": [
"## TensorflowHub\n",
"Let's load the TensorflowHub Embedding class."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f822104b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import TensorflowHubEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bac84e46",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-01-30 23:53:01.652176: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-01-30 23:53:34.362802: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"embeddings = TensorflowHubEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4790d770",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f556dcdb",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "markdown",
"id": "59428e05",
"metadata": {},
"source": [
"## InstructEmbeddings\n",
"Let's load the HuggingFace instruct Embeddings class."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "92c5b61e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import HuggingFaceInstructEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "062547b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load INSTRUCTOR_Transformer\n",
"max_seq_length 512\n"
]
}
],
"source": [
"embeddings = HuggingFaceInstructEmbeddings(\n",
"    query_instruction=\"Represent the query for retrieval: \"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e1dcc4bd",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "90f0db94",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "markdown",
"id": "eec4efda",
"metadata": {},
"source": [
"## Self Hosted Embeddings\n",
"Let's load the SelfHostedEmbeddings, SelfHostedHuggingFaceEmbeddings, and SelfHostedHuggingFaceInstructEmbeddings classes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d338722a",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from langchain.embeddings import (\n",
"    SelfHostedEmbeddings,\n",
"    SelfHostedHuggingFaceEmbeddings,\n",
"    SelfHostedHuggingFaceInstructEmbeddings,\n",
")\n",
"import runhouse as rh"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "146559e8",
"metadata": {},
"outputs": [],
"source": [
"# For an on-demand A100 with GCP, Azure, or Lambda\n",
"gpu = rh.cluster(name=\"rh-a10x\", instance_type=\"A100:1\", use_spot=False)\n",
"\n",
"# For an on-demand A10G with AWS (no single A100s on AWS)\n",
"# gpu = rh.cluster(name='rh-a10x', instance_type='g5.2xlarge', provider='aws')\n",
"\n",
"# For an existing cluster\n",
"# gpu = rh.cluster(ips=['<ip of the cluster>'],\n",
"# ssh_creds={'ssh_user': '...', 'ssh_private_key':'<path_to_key>'},\n",
"# name='my-cluster')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1230f7df",
"metadata": {},
"outputs": [],
"source": [
"embeddings = SelfHostedHuggingFaceEmbeddings(hardware=gpu)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2684e928",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1dc5e606",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "markdown",
"id": "cef9cc54",
"metadata": {},
"source": [
"And similarly for SelfHostedHuggingFaceInstructEmbeddings:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81a17ca3",
"metadata": {},
"outputs": [],
"source": [
"embeddings = SelfHostedHuggingFaceInstructEmbeddings(hardware=gpu)"
]
},
{
"cell_type": "markdown",
"id": "5a33d1c8",
"metadata": {},
"source": [
"Now let's load an embedding model with a custom load function:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c4af5679",
"metadata": {},
"outputs": [],
"source": [
"def get_pipeline():\n",
"    from transformers import (\n",
"        AutoModelForCausalLM,\n",
"        AutoTokenizer,\n",
"        pipeline,\n",
"    )  # Must be inside the function in notebooks\n",
"\n",
"    model_id = \"facebook/bart-base\"\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"    model = AutoModelForCausalLM.from_pretrained(model_id)\n",
"    return pipeline(\"feature-extraction\", model=model, tokenizer=tokenizer)\n",
"\n",
"\n",
"def inference_fn(pipeline, prompt):\n",
"    # Return last hidden state of the model\n",
"    if isinstance(prompt, list):\n",
"        return [emb[0][-1] for emb in pipeline(prompt)]\n",
"    return pipeline(prompt)[0][-1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8654334b",
"metadata": {},
"outputs": [],
"source": [
"embeddings = SelfHostedEmbeddings(\n",
"    model_load_fn=get_pipeline,\n",
"    hardware=gpu,\n",
"    model_reqs=[\"./\", \"torch\", \"transformers\"],\n",
"    inference_fn=inference_fn,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc1bfd0f",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(text)"
]
},
{
"cell_type": "markdown",
"id": "f9c02c78",
"metadata": {},
"source": [
"## Fake Embeddings\n",
"\n",
"LangChain also provides a fake embedding class. You can use this to test your pipelines."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2ffc2e4b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import FakeEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "80777571",
"metadata": {},
"outputs": [],
"source": [
"embeddings = FakeEmbeddings(size=1352)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3ec9d8f0",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3b9ae9e1",
"metadata": {},
"outputs": [],
"source": [
"doc_results = embeddings.embed_documents([\"foo\"])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1f83f273",
"metadata": {},
"source": [
"## SageMaker Endpoint Embeddings\n",
"\n",
"Let's load the SageMaker Endpoint Embeddings class. You can use it if you host, for example, your own Hugging Face model on SageMaker. To learn more, see [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88d366bd",
"metadata": {},
"outputs": [],
"source": [
"!pip3 install langchain boto3"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c5855922",
"metadata": {},
"source": [
"## _### TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_ "
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e0ddd9b4",
"metadata": {},
"outputs": [],
"source": [
"!pip install sagemaker --quiet\n",
"\n",
"import os \n",
"os.environ[\"AWS_DEFAULT_REGION\"] = \"us-east-1\"\n",
"import boto3\n",
"from sagemaker import Session\n",
"# get sagemaker execution role to deploy\n",
"iam = boto3.client('iam')\n",
"role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
"sess = Session()\n",
"# create code/ dir\n",
"os.makedirs(\"model/code\", exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "86ce76c6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Writing model/code/inference.py\n"
]
}
],
"source": [
"%%writefile model/code/inference.py\n",
"\n",
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Helper: Mean Pooling - Take attention mask into account for correct averaging\n",
"def mean_pooling(model_output, attention_mask):\n",
"    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings\n",
"    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
"    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
"\n",
"\n",
"def model_fn(model_dir):\n",
"    # Load model from HuggingFace Hub\n",
"    tokenizer = AutoTokenizer.from_pretrained(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
"    model = AutoModel.from_pretrained(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
"    return model, tokenizer\n",
"\n",
"def predict_fn(data, model_and_tokenizer):\n",
"    # unpack model and tokenizer\n",
"    model, tokenizer = model_and_tokenizer\n",
"\n",
"    # Tokenize sentences\n",
"    sentences = data.pop(\"inputs\", data)\n",
"    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')\n",
"\n",
"    # Compute token embeddings\n",
"    with torch.no_grad():\n",
"        model_output = model(**encoded_input)\n",
"\n",
"    # Perform pooling\n",
"    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n",
"\n",
"    # Normalize embeddings\n",
"    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n",
"\n",
"    # return a dictionary, which will be JSON serializable\n",
"    return {\"embeddings\": sentence_embeddings[0].tolist()}\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "24b809d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"code/\n",
"code/inference.py\n",
"----!"
]
}
],
"source": [
"from sagemaker.s3 import S3Uploader\n",
"from sagemaker.huggingface.model import HuggingFaceModel\n",
"\n",
"# create model.tar.gz and upload to s3\n",
"parent_dir=os.getcwd()\n",
"# change to model dir\n",
"os.chdir(\"model\")\n",
"# create a gzip-compressed archive of the model directory\n",
"!tar zcvf model.tar.gz *\n",
"# change back to parent dir\n",
"os.chdir(parent_dir)\n",
"\n",
"\n",
"# upload model.tar.gz to s3\n",
"s3_model_uri = S3Uploader.upload(local_path=\"model/model.tar.gz\", desired_s3_uri=f\"s3://{sess.default_bucket()}/embeddings\")\n",
"\n",
"# create Hugging Face Model Class\n",
"huggingface_model = HuggingFaceModel(\n",
"    model_data=s3_model_uri,  # path to your model and script\n",
"    role=role,  # iam role with permissions to create an Endpoint\n",
"    transformers_version=\"4.26\",  # transformers version used\n",
"    pytorch_version=\"1.13\",  # pytorch version used\n",
"    py_version='py39',  # python version used\n",
")\n",
"\n",
"# deploy the endpoint\n",
"predictor = huggingface_model.deploy(1,\"ml.m5.2xlarge\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "324213fd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'huggingface-pytorch-inference-2023-03-21-16-14-03-834'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictor.endpoint_name"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3dff3efa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'embeddings': [-0.03833858296275139,\n",
" 0.12346473336219788,\n",
" -0.028642961755394936,\n",
" 0.05365271493792534,\n",
" 0.008845399133861065,\n",
" -0.039839327335357666,\n",
" -0.07300589978694916,\n",
" 0.04777129739522934,\n",
" -0.03046245686709881,\n",
" 0.054979756474494934,\n",
" 0.08505291491746902,\n",
" 0.03665667772293091,\n",
" -0.0053200023248791695,\n",
" -0.002233208389952779,\n",
" -0.06071101501584053,\n",
" -0.027237888425588608,\n",
" -0.011351668275892735,\n",
" -0.04243773967027664,\n",
" 0.009129947051405907,\n",
" 0.10081552714109421,\n",
" 0.075787253677845,\n",
" 0.06911724805831909,\n",
" 0.009857476688921452,\n",
" -0.0018377384403720498,\n",
" 0.02624901942908764,\n",
" 0.03290242329239845,\n",
" -0.07177436351776123,\n",
" 0.028384245932102203,\n",
" 0.06170952320098877,\n",
" -0.05252952501177788,\n",
" 0.033661700785160065,\n",
" 0.07446815073490143,\n",
" 0.07536035776138306,\n",
" 0.03538404032588005,\n",
" 0.0671340748667717,\n",
" 0.01079804077744484,\n",
" 0.08167019486427307,\n",
" 0.01656288281083107,\n",
" 0.03283063322305679,\n",
" 0.03632563352584839,\n",
" 0.002172857290133834,\n",
" -0.09895739704370499,\n",
" 0.005046747159212828,\n",
" 0.050896503031253815,\n",
" 0.009287566877901554,\n",
" 0.024507733061909676,\n",
" -0.0644078254699707,\n",
" 0.0019362837774679065,\n",
" -0.079103484749794,\n",
" 0.020850397646427155,\n",
" -0.01922827586531639,\n",
" -0.02805466018617153,\n",
" -0.07059794664382935,\n",
" -0.007083615753799677,\n",
" 0.01040570717304945,\n",
" 0.038834139704704285,\n",
" 0.01765601523220539,\n",
" -0.019606105983257294,\n",
" -0.020058417692780495,\n",
" 0.018083792179822922,\n",
" -0.00017212114471476525,\n",
" 0.013043343089520931,\n",
" -0.09337250143289566,\n",
" 0.08453577756881714,\n",
" 0.11705499142408371,\n",
" 0.057413410395383835,\n",
" -0.022439058870077133,\n",
" -0.03677624836564064,\n",
" -0.03434618189930916,\n",
" -0.06383830308914185,\n",
" -0.06846101582050323,\n",
" -0.005553076509386301,\n",
" 0.044378429651260376,\n",
" 0.016669290140271187,\n",
" 0.030911751091480255,\n",
" -0.01975969970226288,\n",
" -0.024855101481080055,\n",
" -0.05904391035437584,\n",
" 0.0945875272154808,\n",
" -0.06530515849590302,\n",
" -0.05597255751490593,\n",
" -0.03284724801778793,\n",
" 0.00811521615833044,\n",
" -0.002234684070572257,\n",
" 0.002023296197876334,\n",
" 0.07942128926515579,\n",
" 0.08518771082162857,\n",
" 0.007815245538949966,\n",
" -0.01374559011310339,\n",
" 0.031104223802685738,\n",
" 0.010080904699862003,\n",
" -0.03275560960173607,\n",
" 0.007714808918535709,\n",
" -0.006191879045218229,\n",
" -0.05613413453102112,\n",
" 0.004364899825304747,\n",
" -0.01403757743537426,\n",
" -0.039304714649915695,\n",
" 0.07822350412607193,\n",
" 0.07393720000982285,\n",
" 0.05619140341877937,\n",
" 0.003301335731521249,\n",
" 0.04155803844332695,\n",
" -0.010387539863586426,\n",
" -0.13272696733474731,\n",
" -0.10473112016916275,\n",
" 0.018451808020472527,\n",
" -0.07520624995231628,\n",
" 0.04954085499048233,\n",
" -0.028530888259410858,\n",
" -0.01358408946543932,\n",
" -0.037112679332494736,\n",
" -0.06756578385829926,\n",
" -0.019552525132894516,\n",
" -0.010211824439466,\n",
" -0.051934875547885895,\n",
" -0.05941231921315193,\n",
" 0.016754044219851494,\n",
" 0.04098018631339073,\n",
" 0.001522376318462193,\n",
" 0.08095283806324005,\n",
" 0.002651068614795804,\n",
" -0.03870720788836479,\n",
" -0.04703034833073616,\n",
" -0.05854427441954613,\n",
" -0.029478492215275764,\n",
" 0.03882651776075363,\n",
" -8.102625254868425e-33,\n",
" -0.012914206832647324,\n",
" -0.014458492398262024,\n",
" -0.022368784993886948,\n",
" 0.1056450605392456,\n",
" 0.0037274654023349285,\n",
" 0.005939559079706669,\n",
" -0.023657256737351418,\n",
" 0.041163913905620575,\n",
" -0.07411694526672363,\n",
" 0.007076926529407501,\n",
" 0.0018349214224144816,\n",
" -0.03314222767949104,\n",
" 0.006818821653723717,\n",
" 0.04693515598773956,\n",
" -0.03836120665073395,\n",
" 0.05861291661858559,\n",
" -0.0840379074215889,\n",
" 0.11954139918088913,\n",
" -0.025204092264175415,\n",
" 0.02761165052652359,\n",
" 0.0244757030159235,\n",
" 0.014137371443212032,\n",
" 0.0128665491938591,\n",
" -0.05779572203755379,\n",
" -0.031691741198301315,\n",
" -0.0029006320983171463,\n",
" -0.027254171669483185,\n",
" -0.027451230213046074,\n",
" -0.03404244780540466,\n",
" 0.020136823877692223,\n",
" 0.022654512897133827,\n",
" 0.030933434143662453,\n",
" -0.045505885034799576,\n",
" -0.0025163793470710516,\n",
" 0.01510235108435154,\n",
" 0.09668111801147461,\n",
" 0.001809411682188511,\n",
" -0.05403870716691017,\n",
" 0.0025403527542948723,\n",
" 0.006051000207662582,\n",
" -0.056302234530448914,\n",
" -0.028254246339201927,\n",
" 0.06966646015644073,\n",
" 0.04410792514681816,\n",
" 0.039832279086112976,\n",
" -0.0419430211186409,\n",
" -0.0038099137600511312,\n",
" -0.04156690835952759,\n",
" 0.09482309967279434,\n",
" 0.019028929993510246,\n",
" -0.04011702537536621,\n",
" 0.0324222669005394,\n",
" 0.012565849348902702,\n",
" -0.056325893849134445,\n",
" 0.04461190849542618,\n",
" 0.04928917437791824,\n",
" 0.017442630603909492,\n",
" 0.05323149263858795,\n",
" -0.020876457914710045,\n",
" 0.061462536454200745,\n",
" -0.014837260358035564,\n",
" 0.07423629611730576,\n",
" -0.0576944537460804,\n",
" 0.049852192401885986,\n",
" -0.05890402942895889,\n",
" -0.0006539729074575007,\n",
" -0.10970547795295715,\n",
" -0.06829895824193954,\n",
" 0.13056595623493195,\n",
" -0.011906635947525501,\n",
" -0.0159984789788723,\n",
" -0.0211041159927845,\n",
" -0.007144191302359104,\n",
" -0.0164438858628273,\n",
" -0.016906214877963066,\n",
" -0.04813709110021591,\n",
" 0.015731733292341232,\n",
" 0.030654815956950188,\n",
" -0.004599860403686762,\n",
" -0.03823969140648842,\n",
" -0.04718682914972305,\n",
" -0.08068915456533432,\n",
" -0.011494779027998447,\n",
" -0.05190776288509369,\n",
" -0.04332379251718521,\n",
" -0.019109943881630898,\n",
" 0.036341868340969086,\n",
" -0.06575313955545425,\n",
" -0.014969361014664173,\n",
" -0.0911363959312439,\n",
" 0.035127948969602585,\n",
" 0.019904181361198425,\n",
" -0.055992890149354935,\n",
" -0.04273851588368416,\n",
" 0.11667020618915558,\n",
" 4.7537233992963164e-33,\n",
" -0.04277687147259712,\n",
" 0.010693217627704144,\n",
" -0.08699914813041687,\n",
" 0.11428382992744446,\n",
" 0.026194244623184204,\n",
" 0.008768039755523205,\n",
" 0.08940346539020538,\n",
" -0.0019060149788856506,\n",
" -0.0455072745680809,\n",
" 0.08432017266750336,\n",
" 0.011060485616326332,\n",
" 0.000260289350990206,\n",
" -0.00023178635456133634,\n",
" -0.0015942883910611272,\n",
" 0.0015580946346744895,\n",
" -0.025324126705527306,\n",
" -0.03786805272102356,\n",
" -0.0546313114464283,\n",
" 0.004270816687494516,\n",
" 0.016222011297941208,\n",
" -0.04763113334774971,\n",
" 0.11077607423067093,\n",
" 0.045782990753650665,\n",
" 0.07989457994699478,\n",
" -0.006792569998651743,\n",
" -0.010313649661839008,\n",
" 0.006975427269935608,\n",
" -0.09530742466449738,\n",
" -0.014356936328113079,\n",
" -0.013479162007570267,\n",
" -0.009381195530295372,\n",
" -0.0026153195649385452,\n",
" -0.12162390351295471,\n",
" 0.07765249162912369,\n",
" 0.009094372391700745,\n",
" -0.10183481127023697,\n",
" 0.13146239519119263,\n",
" -0.04587067291140556,\n",
" -0.009605005383491516,\n",
" 0.024302706122398376,\n",
" 0.045921340584754944,\n",
" 0.08771276473999023,\n",
" 0.055159058421850204,\n",
" 0.047116719186306,\n",
" -0.022800585255026817,\n",
" 0.05540422350168228,\n",
" 0.03942396119236946,\n",
" -0.06854791939258575,\n",
" 0.07696892321109772,\n",
" 0.0264807790517807,\n",
" 0.013421732001006603,\n",
" -0.03159027546644211,\n",
" 0.02122318185865879,\n",
" -0.02458374947309494,\n",
" -0.09490033239126205,\n",
" 0.05001789703965187,\n",
" -0.07885674387216568,\n",
" -0.0469261035323143,\n",
" -0.009405327029526234,\n",
" 0.06844945251941681,\n",
" -0.019532756879925728,\n",
" 0.08325397968292236,\n",
" -0.0020212731324136257,\n",
" 0.07861411571502686,\n",
" 0.009707036428153515,\n",
" -0.08329329639673233,\n",
" -0.08883728086948395,\n",
" 0.026159727945923805,\n",
" -0.0036121727898716927,\n",
" 0.0021212503779679537,\n",
" 0.06756487488746643,\n",
" -0.04351912811398506,\n",
" -0.031103378161787987,\n",
" -0.1055448055267334,\n",
" 0.08162888139486313,\n",
" -0.11693760007619858,\n",
" 0.0012153959833085537,\n",
" -0.042226288467645645,\n",
" -0.025040708482265472,\n",
" -0.05382077395915985,\n",
" 0.046688906848430634,\n",
" -0.004659516736865044,\n",
" -0.049144256860017776,\n",
" 0.05339549481868744,\n",
" -0.016824593767523766,\n",
" -0.018911045044660568,\n",
" 0.0021526776254177094,\n",
" 0.010545731522142887,\n",
" -0.02843359299004078,\n",
" 0.06319320946931839,\n",
" -0.041760899126529694,\n",
" 0.03648762032389641,\n",
" -0.028613677248358727,\n",
" 0.012441876344382763,\n",
" -0.030993392691016197,\n",
" -1.827941886745066e-08,\n",
" -0.03364746645092964,\n",
" -0.010457276366651058,\n",
" 0.006326176226139069,\n",
" -0.03394529968500137,\n",
" -0.03437081351876259,\n",
" 0.043725401163101196,\n",
" 0.07607871294021606,\n",
" -0.05076980963349342,\n",
" -0.06551552563905716,\n",
" -0.023710858076810837,\n",
" 0.05217289924621582,\n",
" 0.008229373954236507,\n",
" -0.05053586885333061,\n",
" -0.0046344115398824215,\n",
" 0.04596329480409622,\n",
" -0.048263613134622574,\n",
" -0.007646505255252123,\n",
" -0.0246701892465353,\n",
" -0.05899248272180557,\n",
" 0.02179579623043537,\n",
" -0.033197544515132904,\n",
" 0.026267115026712418,\n",
" 0.019565267488360405,\n",
" 0.022036483511328697,\n",
" -0.02707892283797264,\n",
" 0.07815380394458771,\n",
" 0.03259186074137688,\n",
" 0.10126295685768127,\n",
" 0.007166724652051926,\n",
" -0.031028350815176964,\n",
" 0.04080115631222725,\n",
" 0.10805943608283997,\n",
" -0.00941381324082613,\n",
" -0.01028114091604948,\n",
" 0.037279773503541946,\n",
" 0.11904413253068924,\n",
" 0.04982069879770279,\n",
" 0.05209505558013916,\n",
" 0.020246144384145737,\n",
" 0.05551902949810028,\n",
" -0.10270132124423981,\n",
" -0.009933318942785263,\n",
" -0.022510290145874023,\n",
" 0.03311152011156082,\n",
" 0.05227212607860565,\n",
" -0.029383286833763123,\n",
" -0.1383359581232071,\n",
" -0.014143865555524826,\n",
" -0.037659481167793274,\n",
" -0.08339183777570724,\n",
" -0.0034869578666985035,\n",
" -0.0415429063141346,\n",
" 0.04902830719947815,\n",
" 0.02155115082859993,\n",
" -0.040210600942373276,\n",
" 0.008557669818401337,\n",
" 0.046616844832897186,\n",
" -0.004114149138331413,\n",
" -0.03815949708223343,\n",
" -0.015223635360598564,\n",
" 0.12486445158720016,\n",
" 0.08800436556339264,\n",
" 0.08585748821496964,\n",
" -0.015338928438723087]}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictor.predict({\"inputs\": \"This is a test document.\"})"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3764f159",
"metadata": {},
"source": [
"## _### END TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_ "
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1e9b926a",
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, List\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
"import json\n",
"\n",
"\n",
"class ContentHandler(ContentHandlerBase):\n",
"    content_type = \"application/json\"\n",
"    accepts = \"application/json\"\n",
"\n",
"    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
"        input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
"        return input_str.encode('utf-8')\n",
"\n",
"    def transform_output(self, output: bytes) -> List[float]:\n",
"        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
"        return response_json[\"embeddings\"]\n",
"\n",
"content_handler = ContentHandler()\n",
"\n",
"\n",
"embeddings = SagemakerEndpointEmbeddings(\n",
"    # endpoint_name=\"endpoint-name\",\n",
"    # credentials_profile_name=\"credentials-profile-name\",\n",
"    endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
"    region_name=\"us-east-1\",\n",
"    content_handler=content_handler\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "836e3ea5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.01623339205980301,\n",
" -0.007662336342036724,\n",
" 0.018606489524245262,\n",
" 0.031968992203474045,\n",
" -0.031003747135400772,\n",
" 0.008777972310781479,\n",
" 0.1594553291797638,\n",
" -0.009521624073386192,\n",
" 0.020200366154313087,\n",
" -0.04545809328556061,\n",
" 0.013985812664031982,\n",
" -0.017674963921308517,\n",
" -0.03616964817047119,\n",
" -0.02194339968264103,\n",
" 0.021387653425335884,\n",
" 0.06459270417690277,\n",
" -0.03659535571932793,\n",
" -0.01213359646499157,\n",
" -0.043666232377290726,\n",
" -0.03515005484223366,\n",
" -0.032629866153001785,\n",
" 0.07834123075008392,\n",
" -0.021041689440608025,\n",
" 0.03372766822576523,\n",
" -0.024157941341400146,\n",
" -0.010767146944999695,\n",
" -0.042864806950092316,\n",
" 0.013539575971662998,\n",
" 0.05039731785655022,\n",
" -0.091956727206707,\n",
" 0.035494621843099594,\n",
" 0.18029741942882538,\n",
" 0.01576363667845726,\n",
" -0.04949156939983368,\n",
" -0.003976485226303339,\n",
" 0.00032106428989209235,\n",
" 0.021849628537893295,\n",
" 0.035368386656045914,\n",
" 0.04185418039560318,\n",
" 0.04899369180202484,\n",
" -0.026651302352547646,\n",
" -0.05650882050395012,\n",
" -0.03276852145791054,\n",
" -0.020723465830087662,\n",
" -0.011230835691094398,\n",
" 0.02798161283135414,\n",
" -0.010538998991250992,\n",
" 0.030317796394228935,\n",
" 0.017697133123874664,\n",
" 0.003633821150287986,\n",
" -0.008708533830940723,\n",
" -0.04946836829185486,\n",
" -0.029903240501880646,\n",
" 0.022750651463866234,\n",
" 0.09276428818702698,\n",
" 0.05072581395506859,\n",
" 0.02917262725532055,\n",
" 0.00728880288079381,\n",
" -0.011496285907924175,\n",
" -0.05313197895884514,\n",
" -0.027890320867300034,\n",
" 0.030064044520258904,\n",
" -0.06029842048883438,\n",
" -0.043088313192129135,\n",
" 0.05004483461380005,\n",
" 0.0015685138059780002,\n",
" -0.01834270916879177,\n",
" 0.046504270285367966,\n",
" -0.043405696749687195,\n",
" 0.08440472185611725,\n",
" 0.022881966084241867,\n",
" 0.013790522702038288,\n",
" 0.03525456413626671,\n",
" 0.08282686769962311,\n",
" 0.031224273145198822,\n",
" -0.032255761325359344,\n",
" 0.033190108835697174,\n",
" -0.02879202552139759,\n",
" 0.09641945362091064,\n",
" 0.014541308395564556,\n",
" -0.0425688773393631,\n",
" 0.007836293429136276,\n",
" -0.07434573024511337,\n",
" -0.03844423592090607,\n",
" 0.007907251827418804,\n",
" 0.005604865029454231,\n",
" 0.014666788280010223,\n",
" -0.015787949785590172,\n",
" -0.011632969602942467,\n",
" 0.06502652168273926,\n",
" -0.09462911635637283,\n",
" -0.05418006703257561,\n",
" 0.07266692817211151,\n",
" -0.0059609077870845795,\n",
" 0.015150884166359901,\n",
" 0.033904850482940674,\n",
" 0.04719925299286842,\n",
" -0.006713803857564926,\n",
" -0.0628967359662056,\n",
" 0.2842618525028229,\n",
" -0.007497070357203484,\n",
" 0.11969559639692307,\n",
" 0.047007955610752106,\n",
" -0.023929651826620102,\n",
" 0.015414181165397167,\n",
" -0.029861856251955032,\n",
" -0.014299873262643814,\n",
" 0.018457207828760147,\n",
" 0.05915089696645737,\n",
" -0.03441261500120163,\n",
" -0.01635487750172615,\n",
" -0.021376853808760643,\n",
" -0.01367877796292305,\n",
" -0.04958583787083626,\n",
" 0.01463531143963337,\n",
" 0.01211540587246418,\n",
" -0.005459521431475878,\n",
" 0.005695596802979708,\n",
" -0.04747362807393074,\n",
" -0.05998056381940842,\n",
" -0.042840663343667984,\n",
" 0.04042218253016472,\n",
" -0.006624337285757065,\n",
" 0.025121480226516724,\n",
" -0.054944958537817,\n",
" -0.06516158580780029,\n",
" 0.007337308023124933,\n",
" -6.10322324013495e-33,\n",
" 0.002179093426093459,\n",
" -0.073213592171669,\n",
" -0.014703890308737755,\n",
" 0.00238825217820704,\n",
" 0.02046307921409607,\n",
" -0.06456342339515686,\n",
" 0.014286896213889122,\n",
" 0.02082856185734272,\n",
" -0.07692538946866989,\n",
" 0.09246989339590073,\n",
" -0.03469334542751312,\n",
" 0.022259987890720367,\n",
" -0.0369521826505661,\n",
" -0.0876070111989975,\n",
" 0.13785682618618011,\n",
" -0.000683621852658689,\n",
" 0.0018552240217104554,\n",
" 0.07194776087999344,\n",
" -0.0633404403924942,\n",
" -0.01646324060857296,\n",
" -0.0361541211605072,\n",
" -0.006936112884432077,\n",
" 0.003252814756706357,\n",
" 0.02627389132976532,\n",
" -0.0014277833979576826,\n",
" -0.09001296013593674,\n",
" 0.008833721280097961,\n",
" -0.07455790787935257,\n",
" 0.10064911842346191,\n",
" 0.03227779641747475,\n",
" -0.016069436445832253,\n",
" 0.024673042818903923,\n",
" 0.04188213497400284,\n",
" 0.03961843252182007,\n",
" -0.028469551354646683,\n",
" -0.05262545496225357,\n",
" -0.006966825108975172,\n",
" -0.0033834113273769617,\n",
" -0.038578882813453674,\n",
" -0.010265848599374294,\n",
" -0.033789463341236115,\n",
" 0.0030778711661696434,\n",
" -0.05088731646537781,\n",
" -0.019024258479475975,\n",
" 0.05421010032296181,\n",
" 0.015494044870138168,\n",
" -0.009311210364103317,\n",
" 0.0050599598325788975,\n",
" -0.04918931797146797,\n",
" 0.03970836102962494,\n",
" 0.06579958647489548,\n",
" 0.014110234566032887,\n",
" -0.04829266294836998,\n",
" 0.05065532773733139,\n",
" 0.021345015615224838,\n",
" -0.02805492654442787,\n",
" -0.013115333393216133,\n",
" -0.03833610191941261,\n",
" 0.0081633934751153,\n",
" 0.0020320340991020203,\n",
" 0.025601046159863472,\n",
" 0.046745311468839645,\n",
" -0.07602663338184357,\n",
" 0.08589514344930649,\n",
" -0.09630884975194931,\n",
" 0.01156257651746273,\n",
" 0.047838304191827774,\n",
" -0.03707060590386391,\n",
" 0.05717772990465164,\n",
" -0.028168894350528717,\n",
" -0.06691361963748932,\n",
" 0.003909755032509565,\n",
" -0.01265989150851965,\n",
" -0.024667585268616676,\n",
" -0.04399942606687546,\n",
" 0.013469734229147434,\n",
" 0.013298758305609226,\n",
" 0.0409042127430439,\n",
" -0.012081797234714031,\n",
" -0.009779289364814758,\n",
" 0.021113228052854538,\n",
" -0.06191551312804222,\n",
" -0.010964356362819672,\n",
" 0.027119100093841553,\n",
" -0.03144009783864021,\n",
" 0.037719033658504486,\n",
" 0.02421882562339306,\n",
" -0.13700149953365326,\n",
" 0.0038421833887696266,\n",
" -0.06574120372533798,\n",
" -0.12629178166389465,\n",
" 0.018397213891148567,\n",
" 0.0019562605302780867,\n",
" -0.06581622362136841,\n",
" 0.0056412797421216965,\n",
" 6.17423926197194e-33,\n",
" 0.11609319597482681,\n",
" 0.023075049743056297,\n",
" -0.02540658414363861,\n",
" 0.021112393587827682,\n",
" -0.010050611570477486,\n",
" 0.0045014130882918835,\n",
" 0.02216450683772564,\n",
" 0.03083667904138565,\n",
" -0.065506212413311,\n",
" -0.028498610481619835,\n",
" -0.08708083629608154,\n",
" -0.027195820584893227,\n",
" 0.04075731709599495,\n",
" -0.00738579360768199,\n",
" 0.031747449189424515,\n",
" 0.020246611908078194,\n",
" 0.03285415843129158,\n",
" -0.037579674273729324,\n",
" -0.025780295953154564,\n",
" 0.044498566538095474,\n",
" -0.01600523293018341,\n",
" -0.1110200434923172,\n",
" 0.10275887697935104,\n",
" -0.044455550611019135,\n",
" -0.043082430958747864,\n",
" 0.04361744225025177,\n",
" 0.09388253092765808,\n",
" 0.03668423742055893,\n",
" -0.08740879595279694,\n",
" -0.015174541622400284,\n",
" 0.035617802292108536,\n",
" -0.056008175015449524,\n",
" -0.07729960232973099,\n",
" -0.055068857967853546,\n",
" 0.011802412569522858,\n",
" 0.0005090870545245707,\n",
" -0.04531490430235863,\n",
" 0.009635107591748238,\n",
" 0.0066973078064620495,\n",
" -0.08850639313459396,\n",
" 0.07926266640424728,\n",
" 0.03328349068760872,\n",
" 0.02206319011747837,\n",
" 0.08003410696983337,\n",
" -0.004926585592329502,\n",
" -0.012855191715061665,\n",
" -0.030001787468791008,\n",
" 0.0038301211316138506,\n",
" 0.09513995796442032,\n",
" -0.023254361003637314,\n",
" -0.01384524442255497,\n",
" -0.0006733545451425016,\n",
" 0.004949721973389387,\n",
" -0.03836912661790848,\n",
" -0.0484086349606514,\n",
" -0.04300595819950104,\n",
" -0.03302333503961563,\n",
" -0.011142191477119923,\n",
" -0.021775009110569954,\n",
" 0.009556151926517487,\n",
" -0.014081847853958607,\n",
" 0.01725372113287449,\n",
" -0.002208009362220764,\n",
" 0.043982796370983124,\n",
" -0.12186389416456223,\n",
" -0.03109029121696949,\n",
" -0.0648212656378746,\n",
" -0.03446059674024582,\n",
" -0.0009474779944866896,\n",
" 0.019224559888243675,\n",
" 0.030093936249613762,\n",
" 0.011459640227258205,\n",
" -0.031019125133752823,\n",
" 0.11076018959283829,\n",
" -0.08466918021440506,\n",
" -0.028721166774630547,\n",
" -0.006525673437863588,\n",
" 0.05877530202269554,\n",
" 0.021319882944226265,\n",
" 0.08542844653129578,\n",
" 0.05103899911046028,\n",
" -0.02113465592265129,\n",
" 0.01493119541555643,\n",
" 0.010513859800994396,\n",
" -0.023147936910390854,\n",
" -0.044208601117134094,\n",
" -0.0010544550605118275,\n",
" 0.0656798928976059,\n",
" -0.013098708353936672,\n",
" 0.0029119630344212055,\n",
" 0.03165023773908615,\n",
" 0.06931225955486298,\n",
" -0.02299979329109192,\n",
" 0.022364258766174316,\n",
" -0.04974697157740593,\n",
" -1.3714036128931184e-08,\n",
" -0.0205942764878273,\n",
" 0.047028280794620514,\n",
" -0.032210975885391235,\n",
" 0.049078319221735,\n",
" 0.0394253209233284,\n",
" 0.10298289358615875,\n",
" 0.013628372922539711,\n",
" -0.07071256637573242,\n",
" -0.001111415564082563,\n",
" 0.045793063938617706,\n",
" 0.010663686320185661,\n",
" 0.022661199793219566,\n",
" -0.00039414051570929587,\n",
" 0.04868670925498009,\n",
" 0.08181674033403397,\n",
" -0.06234998628497124,\n",
" -0.017647461965680122,\n",
" -0.05699630081653595,\n",
" -0.035604529082775116,\n",
" -0.002848744625225663,\n",
" -0.07433759421110153,\n",
" 0.05970819666981697,\n",
" -0.03040698915719986,\n",
" -0.03587964177131653,\n",
" -0.05538871884346008,\n",
" -0.007939192466437817,\n",
" -0.015285325236618519,\n",
" 0.08461211621761322,\n",
" 0.01166541874408722,\n",
" 0.03213988244533539,\n",
" 0.05643611401319504,\n",
" 0.2006419152021408,\n",
" -0.07411110401153564,\n",
" -0.018009720370173454,\n",
" 0.016179822385311127,\n",
" -0.0028461480978876352,\n",
" 0.0402149073779583,\n",
" 0.0006247136043384671,\n",
" 0.0006973804556764662,\n",
" 0.09922358393669128,\n",
" -0.029822450131177902,\n",
" -0.005783025175333023,\n",
" -0.0028224103152751923,\n",
" -0.11175407469272614,\n",
" 0.012009709142148495,\n",
" -0.009956827387213707,\n",
" 0.011468647047877312,\n",
" -0.054449401795864105,\n",
" -0.016370657831430435,\n",
" 0.022106735035777092,\n",
" 0.03950563445687294,\n",
" 0.005319684278219938,\n",
" 0.042190469801425934,\n",
" 0.08844445645809174,\n",
" 0.0810166597366333,\n",
" 0.06980433315038681,\n",
" -0.04784897342324257,\n",
" 0.01753094792366028,\n",
" -0.10126522183418274,\n",
" -0.016526369377970695,\n",
" 0.11310216039419174,\n",
" 0.0874418243765831,\n",
" 0.09520682692527771,\n",
" 0.10083616524934769]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
    "query_result = embeddings.embed_query(\"foo\")\n",
    "query_result"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "76f1b752",
"metadata": {},
"outputs": [],
"source": [
"doc_results = embeddings.embed_documents([\"foo\"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "221f2f0e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[0.01623339205980301,\n",
" -0.007662336342036724,\n",
" 0.018606489524245262,\n",
" 0.031968992203474045,\n",
" -0.031003747135400772,\n",
" 0.008777972310781479,\n",
" 0.1594553291797638,\n",
" -0.009521624073386192,\n",
" 0.020200366154313087,\n",
" -0.04545809328556061,\n",
" 0.013985812664031982,\n",
" -0.017674963921308517,\n",
" -0.03616964817047119,\n",
" -0.02194339968264103,\n",
" 0.021387653425335884,\n",
" 0.06459270417690277,\n",
" -0.03659535571932793,\n",
" -0.01213359646499157,\n",
" -0.043666232377290726,\n",
" -0.03515005484223366,\n",
" -0.032629866153001785,\n",
" 0.07834123075008392,\n",
" -0.021041689440608025,\n",
" 0.03372766822576523,\n",
" -0.024157941341400146,\n",
" -0.010767146944999695,\n",
" -0.042864806950092316,\n",
" 0.013539575971662998,\n",
" 0.05039731785655022,\n",
" -0.091956727206707,\n",
" 0.035494621843099594,\n",
" 0.18029741942882538,\n",
" 0.01576363667845726,\n",
" -0.04949156939983368,\n",
" -0.003976485226303339,\n",
" 0.00032106428989209235,\n",
" 0.021849628537893295,\n",
" 0.035368386656045914,\n",
" 0.04185418039560318,\n",
" 0.04899369180202484,\n",
" -0.026651302352547646,\n",
" -0.05650882050395012,\n",
" -0.03276852145791054,\n",
" -0.020723465830087662,\n",
" -0.011230835691094398,\n",
" 0.02798161283135414,\n",
" -0.010538998991250992,\n",
" 0.030317796394228935,\n",
" 0.017697133123874664,\n",
" 0.003633821150287986,\n",
" -0.008708533830940723,\n",
" -0.04946836829185486,\n",
" -0.029903240501880646,\n",
" 0.022750651463866234,\n",
" 0.09276428818702698,\n",
" 0.05072581395506859,\n",
" 0.02917262725532055,\n",
" 0.00728880288079381,\n",
" -0.011496285907924175,\n",
" -0.05313197895884514,\n",
" -0.027890320867300034,\n",
" 0.030064044520258904,\n",
" -0.06029842048883438,\n",
" -0.043088313192129135,\n",
" 0.05004483461380005,\n",
" 0.0015685138059780002,\n",
" -0.01834270916879177,\n",
" 0.046504270285367966,\n",
" -0.043405696749687195,\n",
" 0.08440472185611725,\n",
" 0.022881966084241867,\n",
" 0.013790522702038288,\n",
" 0.03525456413626671,\n",
" 0.08282686769962311,\n",
" 0.031224273145198822,\n",
" -0.032255761325359344,\n",
" 0.033190108835697174,\n",
" -0.02879202552139759,\n",
" 0.09641945362091064,\n",
" 0.014541308395564556,\n",
" -0.0425688773393631,\n",
" 0.007836293429136276,\n",
" -0.07434573024511337,\n",
" -0.03844423592090607,\n",
" 0.007907251827418804,\n",
" 0.005604865029454231,\n",
" 0.014666788280010223,\n",
" -0.015787949785590172,\n",
" -0.011632969602942467,\n",
" 0.06502652168273926,\n",
" -0.09462911635637283,\n",
" -0.05418006703257561,\n",
" 0.07266692817211151,\n",
" -0.0059609077870845795,\n",
" 0.015150884166359901,\n",
" 0.033904850482940674,\n",
" 0.04719925299286842,\n",
" -0.006713803857564926,\n",
" -0.0628967359662056,\n",
" 0.2842618525028229,\n",
" -0.007497070357203484,\n",
" 0.11969559639692307,\n",
" 0.047007955610752106,\n",
" -0.023929651826620102,\n",
" 0.015414181165397167,\n",
" -0.029861856251955032,\n",
" -0.014299873262643814,\n",
" 0.018457207828760147,\n",
" 0.05915089696645737,\n",
" -0.03441261500120163,\n",
" -0.01635487750172615,\n",
" -0.021376853808760643,\n",
" -0.01367877796292305,\n",
" -0.04958583787083626,\n",
" 0.01463531143963337,\n",
" 0.01211540587246418,\n",
" -0.005459521431475878,\n",
" 0.005695596802979708,\n",
" -0.04747362807393074,\n",
" -0.05998056381940842,\n",
" -0.042840663343667984,\n",
" 0.04042218253016472,\n",
" -0.006624337285757065,\n",
" 0.025121480226516724,\n",
" -0.054944958537817,\n",
" -0.06516158580780029,\n",
" 0.007337308023124933,\n",
" -6.10322324013495e-33,\n",
" 0.002179093426093459,\n",
" -0.073213592171669,\n",
" -0.014703890308737755,\n",
" 0.00238825217820704,\n",
" 0.02046307921409607,\n",
" -0.06456342339515686,\n",
" 0.014286896213889122,\n",
" 0.02082856185734272,\n",
" -0.07692538946866989,\n",
" 0.09246989339590073,\n",
" -0.03469334542751312,\n",
" 0.022259987890720367,\n",
" -0.0369521826505661,\n",
" -0.0876070111989975,\n",
" 0.13785682618618011,\n",
" -0.000683621852658689,\n",
" 0.0018552240217104554,\n",
" 0.07194776087999344,\n",
" -0.0633404403924942,\n",
" -0.01646324060857296,\n",
" -0.0361541211605072,\n",
" -0.006936112884432077,\n",
" 0.003252814756706357,\n",
" 0.02627389132976532,\n",
" -0.0014277833979576826,\n",
" -0.09001296013593674,\n",
" 0.008833721280097961,\n",
" -0.07455790787935257,\n",
" 0.10064911842346191,\n",
" 0.03227779641747475,\n",
" -0.016069436445832253,\n",
" 0.024673042818903923,\n",
" 0.04188213497400284,\n",
" 0.03961843252182007,\n",
" -0.028469551354646683,\n",
" -0.05262545496225357,\n",
" -0.006966825108975172,\n",
" -0.0033834113273769617,\n",
" -0.038578882813453674,\n",
" -0.010265848599374294,\n",
" -0.033789463341236115,\n",
" 0.0030778711661696434,\n",
" -0.05088731646537781,\n",
" -0.019024258479475975,\n",
" 0.05421010032296181,\n",
" 0.015494044870138168,\n",
" -0.009311210364103317,\n",
" 0.0050599598325788975,\n",
" -0.04918931797146797,\n",
" 0.03970836102962494,\n",
" 0.06579958647489548,\n",
" 0.014110234566032887,\n",
" -0.04829266294836998,\n",
" 0.05065532773733139,\n",
" 0.021345015615224838,\n",
" -0.02805492654442787,\n",
" -0.013115333393216133,\n",
" -0.03833610191941261,\n",
" 0.0081633934751153,\n",
" 0.0020320340991020203,\n",
" 0.025601046159863472,\n",
" 0.046745311468839645,\n",
" -0.07602663338184357,\n",
" 0.08589514344930649,\n",
" -0.09630884975194931,\n",
" 0.01156257651746273,\n",
" 0.047838304191827774,\n",
" -0.03707060590386391,\n",
" 0.05717772990465164,\n",
" -0.028168894350528717,\n",
" -0.06691361963748932,\n",
" 0.003909755032509565,\n",
" -0.01265989150851965,\n",
" -0.024667585268616676,\n",
" -0.04399942606687546,\n",
" 0.013469734229147434,\n",
" 0.013298758305609226,\n",
" 0.0409042127430439,\n",
" -0.012081797234714031,\n",
" -0.009779289364814758,\n",
" 0.021113228052854538,\n",
" -0.06191551312804222,\n",
" -0.010964356362819672,\n",
" 0.027119100093841553,\n",
" -0.03144009783864021,\n",
" 0.037719033658504486,\n",
" 0.02421882562339306,\n",
" -0.13700149953365326,\n",
" 0.0038421833887696266,\n",
" -0.06574120372533798,\n",
" -0.12629178166389465,\n",
" 0.018397213891148567,\n",
" 0.0019562605302780867,\n",
" -0.06581622362136841,\n",
" 0.0056412797421216965,\n",
" 6.17423926197194e-33,\n",
" 0.11609319597482681,\n",
" 0.023075049743056297,\n",
" -0.02540658414363861,\n",
" 0.021112393587827682,\n",
" -0.010050611570477486,\n",
" 0.0045014130882918835,\n",
" 0.02216450683772564,\n",
" 0.03083667904138565,\n",
" -0.065506212413311,\n",
" -0.028498610481619835,\n",
" -0.08708083629608154,\n",
" -0.027195820584893227,\n",
" 0.04075731709599495,\n",
" -0.00738579360768199,\n",
" 0.031747449189424515,\n",
" 0.020246611908078194,\n",
" 0.03285415843129158,\n",
" -0.037579674273729324,\n",
" -0.025780295953154564,\n",
" 0.044498566538095474,\n",
" -0.01600523293018341,\n",
" -0.1110200434923172,\n",
" 0.10275887697935104,\n",
" -0.044455550611019135,\n",
" -0.043082430958747864,\n",
" 0.04361744225025177,\n",
" 0.09388253092765808,\n",
" 0.03668423742055893,\n",
" -0.08740879595279694,\n",
" -0.015174541622400284,\n",
" 0.035617802292108536,\n",
" -0.056008175015449524,\n",
" -0.07729960232973099,\n",
" -0.055068857967853546,\n",
" 0.011802412569522858,\n",
" 0.0005090870545245707,\n",
" -0.04531490430235863,\n",
" 0.009635107591748238,\n",
" 0.0066973078064620495,\n",
" -0.08850639313459396,\n",
" 0.07926266640424728,\n",
" 0.03328349068760872,\n",
" 0.02206319011747837,\n",
" 0.08003410696983337,\n",
" -0.004926585592329502,\n",
" -0.012855191715061665,\n",
" -0.030001787468791008,\n",
" 0.0038301211316138506,\n",
" 0.09513995796442032,\n",
" -0.023254361003637314,\n",
" -0.01384524442255497,\n",
" -0.0006733545451425016,\n",
" 0.004949721973389387,\n",
" -0.03836912661790848,\n",
" -0.0484086349606514,\n",
" -0.04300595819950104,\n",
" -0.03302333503961563,\n",
" -0.011142191477119923,\n",
" -0.021775009110569954,\n",
" 0.009556151926517487,\n",
" -0.014081847853958607,\n",
" 0.01725372113287449,\n",
" -0.002208009362220764,\n",
" 0.043982796370983124,\n",
" -0.12186389416456223,\n",
" -0.03109029121696949,\n",
" -0.0648212656378746,\n",
" -0.03446059674024582,\n",
" -0.0009474779944866896,\n",
" 0.019224559888243675,\n",
" 0.030093936249613762,\n",
" 0.011459640227258205,\n",
" -0.031019125133752823,\n",
" 0.11076018959283829,\n",
" -0.08466918021440506,\n",
" -0.028721166774630547,\n",
" -0.006525673437863588,\n",
" 0.05877530202269554,\n",
" 0.021319882944226265,\n",
" 0.08542844653129578,\n",
" 0.05103899911046028,\n",
" -0.02113465592265129,\n",
" 0.01493119541555643,\n",
" 0.010513859800994396,\n",
" -0.023147936910390854,\n",
" -0.044208601117134094,\n",
" -0.0010544550605118275,\n",
" 0.0656798928976059,\n",
" -0.013098708353936672,\n",
" 0.0029119630344212055,\n",
" 0.03165023773908615,\n",
" 0.06931225955486298,\n",
" -0.02299979329109192,\n",
" 0.022364258766174316,\n",
" -0.04974697157740593,\n",
" -1.3714036128931184e-08,\n",
" -0.0205942764878273,\n",
" 0.047028280794620514,\n",
" -0.032210975885391235,\n",
" 0.049078319221735,\n",
" 0.0394253209233284,\n",
" 0.10298289358615875,\n",
" 0.013628372922539711,\n",
" -0.07071256637573242,\n",
" -0.001111415564082563,\n",
" 0.045793063938617706,\n",
" 0.010663686320185661,\n",
" 0.022661199793219566,\n",
" -0.00039414051570929587,\n",
" 0.04868670925498009,\n",
" 0.08181674033403397,\n",
" -0.06234998628497124,\n",
" -0.017647461965680122,\n",
" -0.05699630081653595,\n",
" -0.035604529082775116,\n",
" -0.002848744625225663,\n",
" -0.07433759421110153,\n",
" 0.05970819666981697,\n",
" -0.03040698915719986,\n",
" -0.03587964177131653,\n",
" -0.05538871884346008,\n",
" -0.007939192466437817,\n",
" -0.015285325236618519,\n",
" 0.08461211621761322,\n",
" 0.01166541874408722,\n",
" 0.03213988244533539,\n",
" 0.05643611401319504,\n",
" 0.2006419152021408,\n",
" -0.07411110401153564,\n",
" -0.018009720370173454,\n",
" 0.016179822385311127,\n",
" -0.0028461480978876352,\n",
" 0.0402149073779583,\n",
" 0.0006247136043384671,\n",
" 0.0006973804556764662,\n",
" 0.09922358393669128,\n",
" -0.029822450131177902,\n",
" -0.005783025175333023,\n",
" -0.0028224103152751923,\n",
" -0.11175407469272614,\n",
" 0.012009709142148495,\n",
" -0.009956827387213707,\n",
" 0.011468647047877312,\n",
" -0.054449401795864105,\n",
" -0.016370657831430435,\n",
" 0.022106735035777092,\n",
" 0.03950563445687294,\n",
" 0.005319684278219938,\n",
" 0.042190469801425934,\n",
" 0.08844445645809174,\n",
" 0.0810166597366333,\n",
" 0.06980433315038681,\n",
" -0.04784897342324257,\n",
" 0.01753094792366028,\n",
" -0.10126522183418274,\n",
" -0.016526369377970695,\n",
" 0.11310216039419174,\n",
" 0.0874418243765831,\n",
" 0.09520682692527771,\n",
" 0.10083616524934769]]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"doc_results"
]
  }
],
"metadata": {
"kernelspec": {
"display_name": "langchain",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"vscode": {
"interpreter": {
"hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}