Files
langchain/docs/modules/document_loaders/examples/sagemaker.ipynb
2023-02-21 17:02:04 -08:00

184 lines
9.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Defaulting to user installation because normal site-packages is not writeable\n",
"Collecting langchain\n",
" Downloading langchain-0.0.80-py3-none-any.whl (222 kB)\n",
"\u001b[K |████████████████████████████████| 222 kB 2.1 MB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied: numpy<2,>=1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (1.24.1)\n",
"Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (3.8.3)\n",
"Collecting pydantic<2,>=1\n",
" Downloading pydantic-1.10.4-cp39-cp39-macosx_11_0_arm64.whl (2.6 MB)\n",
"\u001b[K |████████████████████████████████| 2.6 MB 3.3 MB/s eta 0:00:01\n",
"\u001b[?25hCollecting SQLAlchemy<2,>=1\n",
" Downloading SQLAlchemy-1.4.46.tar.gz (8.5 MB)\n",
"\u001b[K |████████████████████████████████| 8.5 MB 23.4 MB/s eta 0:00:01\n",
"\u001b[?25hCollecting tenacity<9.0.0,>=8.1.0\n",
" Downloading tenacity-8.2.0-py3-none-any.whl (24 kB)\n",
"Requirement already satisfied: requests<3,>=2 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (2.28.2)\n",
"Requirement already satisfied: PyYAML<7,>=6 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from langchain) (6.0)\n",
"Collecting dataclasses-json<0.6.0,>=0.5.7\n",
" Downloading dataclasses_json-0.5.7-py3-none-any.whl (25 kB)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (4.0.2)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.4)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (22.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.8.2)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
"Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (2.1.1)\n",
"Collecting marshmallow<4.0.0,>=3.3.0\n",
" Downloading marshmallow-3.19.0-py3-none-any.whl (49 kB)\n",
"\u001b[K |████████████████████████████████| 49 kB 26.9 MB/s eta 0:00:01\n",
"\u001b[?25hCollecting marshmallow-enum<2.0.0,>=1.5.1\n",
" Downloading marshmallow_enum-1.5.1-py2.py3-none-any.whl (4.2 kB)\n",
"Collecting typing-inspect>=0.4.0\n",
" Downloading typing_inspect-0.8.0-py3-none-any.whl (8.7 kB)\n",
"Requirement already satisfied: packaging>=17.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from marshmallow<4.0.0,>=3.3.0->dataclasses-json<0.6.0,>=0.5.7->langchain) (23.0)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from pydantic<2,>=1->langchain) (4.4.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (1.26.14)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/nmehta/Library/Python/3.9/lib/python/site-packages (from requests<3,>=2->langchain) (2022.12.7)\n",
"Collecting mypy-extensions>=0.3.0\n",
" Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
"Building wheels for collected packages: SQLAlchemy\n",
" Building wheel for SQLAlchemy (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for SQLAlchemy: filename=SQLAlchemy-1.4.46-cp39-cp39-macosx_10_9_universal2.whl size=1578667 sha256=9991d70fde083b993d7fe1fd61fca33a279e921f1b8296b02037e24b8cac1097\n",
" Stored in directory: /Users/nmehta/Library/Caches/pip/wheels/3c/99/65/57cf5a0ec6e7f3b803a68d31694501e168997e03e80adc903d\n",
"Successfully built SQLAlchemy\n",
"Installing collected packages: mypy-extensions, marshmallow, typing-inspect, marshmallow-enum, tenacity, SQLAlchemy, pydantic, dataclasses-json, langchain\n",
"\u001b[33m WARNING: The script langchain-server is installed in '/Users/nmehta/Library/Python/3.9/bin' which is not on PATH.\n",
" Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\n",
"Successfully installed SQLAlchemy-1.4.46 dataclasses-json-0.5.7 langchain-0.0.80 marshmallow-3.19.0 marshmallow-enum-1.5.1 mypy-extensions-1.0.0 pydantic-1.10.4 tenacity-8.2.0 typing-inspect-0.8.0\n",
"\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.0 is available.\n",
"You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
"Defaulting to user installation because normal site-packages is not writeable\n",
"Collecting html2text\n",
" Downloading html2text-2020.1.16-py3-none-any.whl (32 kB)\n",
"Installing collected packages: html2text\n",
"\u001b[33m WARNING: The script html2text is installed in '/Users/nmehta/Library/Python/3.9/bin' which is not on PATH.\n",
" Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\n",
"Successfully installed html2text-2020.1.16\n",
"\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.0 is available.\n",
"You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n"
]
}
],
"source": [
"!pip3 install langchain"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.docstore.document import Document"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"example_doc_1 = \"\"\"\n",
"Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.\n",
"Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.\n",
"Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n",
"\"\"\"\n",
"\n",
"docs = [\n",
" Document(\n",
" page_content=example_doc_1,\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'output_text': '3 days'}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain import PromptTemplate, HuggingFaceHub, LLMChain, SagemakerEndpoint\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"import json\n",
"\n",
"query = \"\"\"How long was Elizabeth hospitalized?\n",
"\"\"\"\n",
"\n",
"prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
"\n",
"{context}\n",
"\n",
"Question: {question}\n",
"Answer:\"\"\"\n",
"PROMPT = PromptTemplate(\n",
" template=prompt_template, input_variables=[\"context\", \"question\"]\n",
")\n",
"\n",
"def model_input_transform_fn(prompt, model_kwargs):\n",
" parameter_payload = {\"inputs\": prompt, \"parameters\": model_kwargs}\n",
" return json.dumps(parameter_payload).encode(\"utf-8\") \n",
"\n",
"chain = load_qa_chain(llm=SagemakerEndpoint(\n",
" endpoint_name=\"my-sagemaker-model-endpoint\", \n",
" credentials_profile_name=\"credentials-profile-name\", \n",
" region_name=\"us-west-2\", \n",
" model_kwargs={\"temperature\":1e-10},\n",
" content_type=\"application/json\", \n",
" model_input_transform_fn=model_input_transform_fn), \n",
" prompt=PROMPT) \n",
"\n",
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}