Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-18 04:25:22 +00:00)

Compare commits: rlm/LLaMA2 ... v0.0.318 (19 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 76d3afaef0 | |
| | 5dd2161c4b | |
| | 720ecacb1c | |
| | 8425f33363 | |
| | 4adabd33ac | |
| | c9f1768cb9 | |
| | 84d250f781 | |
| | 7db6aabf65 | |
| | ed62984cb2 | |
| | f818ec49b8 | |
| | 1da6d92369 | |
| | a6b483dcbc | |
| | 008c7df80d | |
| | 77fc2f7644 | |
| | 2661dc94f3 | |
| | 4b6fdd7bf0 | |
| | 2038c7fd5d | |
| | dfb4baa3f9 | |
| | 12f8e87a0e | |
@@ -40,6 +40,7 @@ Notebook | Description
[openai_functions_retrieval_qa....](https://github.com/langchain-ai/langchain/tree/master/cookbook/openai_functions_retrieval_qa.ipynb) | Structure response output in a question answering system by incorporating OpenAI functions into a retrieval pipeline.
[petting_zoo.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/petting_zoo.ipynb) | Create multi-agent simulations with simulated environments using the Petting Zoo library.
[plan_and_execute_agent.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/plan_and_execute_agent.ipynb) | Create plan-and-execute agents that accomplish objectives by planning tasks with a language model (LLM) and executing them with a separate agent.
[press_releases.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/press_releases.ipynb) | Retrieve and query company press release data powered by [Kay.ai](https://kay.ai).
[program_aided_language_model.i...](https://github.com/langchain-ai/langchain/tree/master/cookbook/program_aided_language_model.ipynb) | Implement program-aided language models as described in the provided research paper.
[sales_agent_with_context.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/sales_agent_with_context.ipynb) | Implement a context-aware AI sales agent, SalesGPT, that can have natural sales conversations, interact with other systems, and use a product knowledge base to discuss a company's offerings.
[self_query_hotel_search.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/self_query_hotel_search.ipynb) | Build a hotel room search feature with self-querying retrieval, using a specific hotel recommendation dataset.
152 cookbook/press_releases.ipynb Normal file
@@ -0,0 +1,152 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "62ee82e4-2ad8-498b-8438-fac388afe1a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Press Releases Data\n",
|
||||
"=\n",
|
||||
"\n",
|
||||
"Press Releases data powered by [Kay.ai](https://kay.ai).\n",
|
||||
"\n",
|
||||
">Press releases are used by companies to announce something noteworthy, including product launches, financial performance reports, partnerships, and other significant news. They are widely used by analysts to track corporate strategy, operational updates and financial performance.\n",
|
||||
"Kay.ai obtains press releases of all US public companies from a variety of sources, which include the company's official press room and partnerships with various data API providers. \n",
|
||||
"This data is updated through Sept 30th for free access. If you want to access the real-time feed, reach out to us at hello@kay.ai or [tweet at us](https://twitter.com/vishalrohra_)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8183d85d-365f-4672-a963-52b533547de0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Setup\n",
|
||||
"=\n",
|
||||
"\n",
|
||||
"First you will need to install the `kay` package. You will also need an API key: you can get one for free at [https://kay.ai](https://kay.ai/). Once you have an API key, you must set it as an environment variable `KAY_API_KEY`.\n",
|
||||
"\n",
|
||||
"In this example we're going to use the `KayAiRetriever`. Take a look at the [kay notebook](/docs/integrations/retrievers/kay) for more detailed information on the parameters that it accepts."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "02ec21c7-49fe-4844-b58a-bf064ad40b2a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Examples\n",
|
||||
"="
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "bf0395f7-6ebe-4136-8b0d-00b9dea3becd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdin",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" ········\n",
|
||||
" ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Setup API keys for Kay and OpenAI\n",
|
||||
"from getpass import getpass\n",
|
||||
"KAY_API_KEY = getpass()\n",
|
||||
"OPENAI_API_KEY = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "f7fcaf70-29a4-444b-8f07-9784f808c300",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"KAY_API_KEY\"] = KAY_API_KEY\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "ac00bf93-3635-4ffe-b9a6-a8b4f35c0c85",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import ConversationalRetrievalChain\n",
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.retrievers import KayAiRetriever\n",
|
||||
"\n",
|
||||
"model = ChatOpenAI(model_name=\"gpt-3.5-turbo\")\n",
|
||||
"retriever = KayAiRetriever.create(dataset_id=\"company\", data_types=[\"PressRelease\"], num_contexts=6)\n",
|
||||
"qa = ConversationalRetrievalChain.from_llm(model, retriever=retriever)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "8d9d927c-35b2-4a7b-8ea7-4d0350797941",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"-> **Question**: How is the healthcare industry adopting generative AI tools? \n",
|
||||
"\n",
|
||||
"**Answer**: The healthcare industry is adopting generative AI tools to improve various aspects of patient care and administrative tasks. Companies like HCA Healthcare Inc, Amazon Com Inc, and Mayo Clinic have collaborated with technology providers like Google Cloud, AWS, and Microsoft to implement generative AI solutions.\n",
|
||||
"\n",
|
||||
"HCA Healthcare is testing a nurse handoff tool that generates draft reports quickly and accurately, which nurses have shown interest in using. They are also exploring the use of Google's medically-tuned Med-PaLM 2 LLM to support caregivers in asking complex medical questions.\n",
|
||||
"\n",
|
||||
"Amazon Web Services (AWS) has introduced AWS HealthScribe, a generative AI-powered service that automatically creates clinical documentation. However, integrating multiple AI systems into a cohesive solution requires significant engineering resources, including access to AI experts, healthcare data, and compute capacity.\n",
|
||||
"\n",
|
||||
"Mayo Clinic is among the first healthcare organizations to deploy Microsoft 365 Copilot, a generative AI service that combines large language models with organizational data from Microsoft 365. This tool has the potential to automate tasks like form-filling, relieving administrative burdens on healthcare providers and allowing them to focus more on patient care.\n",
|
||||
"\n",
|
||||
"Overall, the healthcare industry is recognizing the potential benefits of generative AI tools in improving efficiency, automating tasks, and enhancing patient care. \n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# More sample questions in the Playground on https://kay.ai\n",
|
||||
"questions = [\n",
|
||||
" \"How is the healthcare industry adopting generative AI tools?\",\n",
|
||||
" #\"What are some recent challenges faced by the renewable energy sector?\",\n",
|
||||
"]\n",
|
||||
"chat_history = []\n",
|
||||
"\n",
|
||||
"for question in questions:\n",
|
||||
" result = qa({\"question\": question, \"chat_history\": chat_history})\n",
|
||||
" chat_history.append((question, result[\"answer\"]))\n",
|
||||
" print(f\"-> **Question**: {question} \\n\")\n",
|
||||
" print(f\"**Answer**: {result['answer']} \\n\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
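Because the chain above threads `chat_history` through each call, a follow-up turn can reference the earlier answer. A minimal sketch, assuming the `qa` chain and `chat_history` built in the notebook; the follow-up question itself is made up:

```python
# Hypothetical follow-up turn reusing the notebook's `qa` chain and history.
followup = "Which cloud providers were mentioned in that answer?"
result = qa({"question": followup, "chat_history": chat_history})
chat_history.append((followup, result["answer"]))
print(result["answer"])
```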
||||
@@ -229,7 +229,7 @@
|
||||
"- fasttext (recommended)\n",
|
||||
"- langdetect\n",
|
||||
"\n",
|
||||
"From our exprience *fasttext* performs a bit better, but you should verify it on your use case."
|
||||
"From our experience *fasttext* performs a bit better, but you should verify it on your use case."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
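The hunk above comes from a notebook that compares two language-detection backends. A minimal, illustrative comparison outside the notebook, assuming `langdetect` and `fasttext` are installed and that fasttext's pretrained `lid.176.ftz` language-identification model has already been downloaded from fasttext.cc:

```python
# Both detectors applied to the same sample string (output values are illustrative).
import fasttext
from langdetect import detect

text = "Me llamo Sofía y vivo en Madrid."

print(detect(text))  # langdetect returns an ISO code, e.g. 'es'

model = fasttext.load_model("lid.176.ftz")  # pretrained language-ID model
labels, scores = model.predict(text)
print(labels[0], scores[0])                 # e.g. '__label__es' plus a confidence score
```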
@@ -21,7 +21,7 @@
|
||||
"\n",
|
||||
"In this notebook, we will look at building a basic system for question answering, based on private data. Before feeding the LLM with this data, we need to protect it so that it doesn't go to an external API (e.g. OpenAI, Anthropic). Then, after receiving the model output, we would like the data to be restored to its original form. Below you can observe an example flow of this QA system:\n",
|
||||
"\n",
|
||||
"<img src=\"/img/qa_privacy_protection.png\" width=\"800\"/>\n",
|
||||
"<img src=\"/img/qa_privacy_protection.png\" width=\"900\"/>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"In the following notebook, we will not go into the details of how the anonymizer works. If you are interested, please visit [this part of the documentation](https://python.langchain.com/docs/guides/privacy/presidio_data_anonymization/).\n",
|
||||
@@ -839,6 +839,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"documents = [Document(page_content=document_content)]\n",
|
||||
"\n",
|
||||
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)\n",
|
||||
"chunks = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
|
||||
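As an aside on the splitter parameters shown above (illustrative only, not part of the diff): `RecursiveCharacterTextSplitter` caps each chunk at `chunk_size` characters and lets consecutive chunks share up to `chunk_overlap` characters so that context is not lost at chunk boundaries.

```python
# Toy illustration of chunk_size / chunk_overlap on a short string.
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)
for chunk in splitter.split_text("abcdefghij " * 5):
    print(repr(chunk))  # each chunk is at most 20 characters long
```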
121 docs/docs/integrations/chat/pai_eas_chat_endpoint.ipynb Normal file
@@ -0,0 +1,121 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AliCloud PAI EAS\n",
|
||||
"Machine Learning Platform for AI of Alibaba Cloud is a machine learning or deep learning engineering platform intended for enterprises and developers. It provides easy-to-use, cost-effective, high-performance, and easy-to-scale plug-ins that can be applied to various industry scenarios. With over 140 built-in optimization algorithms, Machine Learning Platform for AI provides whole-process AI engineering capabilities including data labeling (PAI-iTAG), model building (PAI-Designer and PAI-DSW), model training (PAI-DLC), compilation optimization, and inference deployment (PAI-EAS). PAI-EAS supports different types of hardware resources, including CPUs and GPUs, and features high throughput and low latency. It allows you to deploy large-scale complex models with a few clicks and perform elastic scale-ins and scale-outs in real time. It also provides a comprehensive O&M and monitoring system."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set Up the EAS Service\n",
|
||||
"\n",
|
||||
"To use EAS LLMs, you must first set up the EAS service. When the EAS service is launched, you can obtain the EAS service URL and EAS service token. Refer to https://www.alibabacloud.com/help/en/pai/user-guide/service-deployment/ for more information. Set environment variables to initialize the EAS service URL and token:\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"export EAS_SERVICE_URL=XXX\n",
|
||||
"export EAS_SERVICE_TOKEN=XXX\n",
|
||||
"```\n",
|
||||
"or run the following code:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from langchain.chat_models.base import HumanMessage\n",
|
||||
"from langchain.chat_models import PaiEasChatEndpoint\n",
|
||||
"os.environ[\"EAS_SERVICE_URL\"] = \"Your_EAS_Service_URL\"\n",
|
||||
"os.environ[\"EAS_SERVICE_TOKEN\"] = \"Your_EAS_Service_Token\"\n",
|
||||
"chat = PaiEasChatEndpoint(\n",
|
||||
" eas_service_url=os.environ[\"EAS_SERVICE_URL\"], \n",
|
||||
" eas_service_token=os.environ[\"EAS_SERVICE_TOKEN\"]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run Chat Model\n",
|
||||
"You can use the default settings to call the EAS service as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"output = chat([HumanMessage(content=\"write a funny joke\")])\n",
|
||||
"print(\"output:\", output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Or, call the EAS service with new inference parameters:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kwargs = {\"temperature\": 0.8, \"top_p\": 0.8, \"top_k\": 5}\n",
|
||||
"output = chat([HumanMessage(content=\"write a funny joke\")], **kwargs)\n",
|
||||
"print(\"output:\", output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Or, make a streaming call to get a streamed response:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"outputs = chat.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
|
||||
"for output in outputs:\n",
|
||||
" print(\"stream output:\", output)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
93 docs/docs/integrations/llms/pai_eas_endpoint.ipynb Normal file
@@ -0,0 +1,93 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AliCloud PAI EAS\n",
|
||||
"Machine Learning Platform for AI of Alibaba Cloud is a machine learning or deep learning engineering platform intended for enterprises and developers. It provides easy-to-use, cost-effective, high-performance, and easy-to-scale plug-ins that can be applied to various industry scenarios. With over 140 built-in optimization algorithms, Machine Learning Platform for AI provides whole-process AI engineering capabilities including data labeling (PAI-iTAG), model building (PAI-Designer and PAI-DSW), model training (PAI-DLC), compilation optimization, and inference deployment (PAI-EAS). PAI-EAS supports different types of hardware resources, including CPUs and GPUs, and features high throughput and low latency. It allows you to deploy large-scale complex models with a few clicks and perform elastic scale-ins and scale-outs in real time. It also provides a comprehensive O&M and monitoring system."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms.pai_eas_endpoint import PaiEasEndpoint\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"\n",
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To use EAS LLMs, you must first set up the EAS service. When the EAS service is launched, you can obtain the EAS service URL and EAS service token. Refer to https://www.alibabacloud.com/help/en/pai/user-guide/service-deployment/ for more information."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"EAS_SERVICE_URL\"] = \"Your_EAS_Service_URL\"\n",
|
||||
"os.environ[\"EAS_SERVICE_TOKEN\"] = \"Your_EAS_Service_Token\"\n",
|
||||
"llm = PaiEasEndpoint(eas_service_url=os.environ[\"EAS_SERVICE_URL\"], eas_service_token=os.environ[\"EAS_SERVICE_TOKEN\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' Thank you for asking! However, I must respectfully point out that the question contains an error. Justin Bieber was born in 1994, and the Super Bowl was first played in 1967. Therefore, it is not possible for any NFL team to have won the Super Bowl in the year Justin Bieber was born.\\n\\nI hope this clarifies things! If you have any other questions, please feel free to ask.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"\n",
|
||||
"question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
|
||||
"llm_chain.run(question)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,272 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Google Cloud Enterprise Search\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"[Enterprise Search](https://cloud.google.com/enterprise-search) is a part of the Generative AI App Builder suite of tools offered by Google Cloud.\n",
|
||||
"\n",
|
||||
"Gen AI App Builder lets developers, even those with limited machine learning skills, quickly and easily tap into the power of Google’s foundation models, search expertise, and conversational AI technologies to create enterprise-grade generative AI applications. \n",
|
||||
"\n",
|
||||
"Enterprise Search lets organizations quickly build generative AI-powered search engines for customers and employees. Enterprise Search is underpinned by a variety of Google Search technologies, including semantic search, which helps deliver more relevant results than traditional keyword-based search techniques by using natural language processing and machine learning techniques to infer relationships within the content and intent from the user’s query input. Enterprise Search also benefits from Google’s expertise in understanding how users search and factors in content relevance to order displayed results. \n",
|
||||
"\n",
|
||||
"Google Cloud offers Enterprise Search via Gen App Builder in Google Cloud Console and via an API for enterprise workflow integration. \n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to configure Enterprise Search and use the Enterprise Search retriever. The Enterprise Search retriever encapsulates the [Generative AI App Builder Python client library](https://cloud.google.com/generative-ai-app-builder/docs/libraries#client-libraries-install-python) and uses it to access the Enterprise Search [Search Service API](https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine_v1beta.services.search_service)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Install pre-requisites\n",
|
||||
"\n",
|
||||
"You need to install the `google-cloud-discoveryengine` package to use the Enterprise Search retriever."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install google-cloud-discoveryengine"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configure access to Google Cloud and Google Cloud Enterprise Search\n",
|
||||
"\n",
|
||||
"Enterprise Search is generally available for the allowlist (which means customers need to be approved for access) as of June 6, 2023. Contact your Google Cloud sales team for access and pricing details. We are previewing additional features that are coming soon to the generally available offering as part of our [Trusted Tester](https://cloud.google.com/ai/earlyaccess/join?hl=en) program. Sign up for [Trusted Tester](https://cloud.google.com/ai/earlyaccess/join?hl=en) and contact your Google Cloud sales team for an expedited trial.\n",
|
||||
"\n",
|
||||
"Before you can run this notebook you need to:\n",
|
||||
"- Set or create a Google Cloud project and turn on Gen App Builder\n",
|
||||
"- Create and populate an unstructured data store\n",
|
||||
"- Set credentials to access `Enterprise Search API`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Set or create a Google Cloud project and turn on Gen App Builder\n",
|
||||
"\n",
|
||||
"Follow the instructions in the [Enterprise Search Getting Started guide](https://cloud.google.com/generative-ai-app-builder/docs/before-you-begin) to set/create a GCP project and enable Gen App Builder.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create and populate an unstructured data store\n",
|
||||
"\n",
|
||||
"[Use Google Cloud Console to create an unstructured data store](https://cloud.google.com/generative-ai-app-builder/docs/create-engine-es#unstructured-data) and populate it with the example PDF documents from the `gs://cloud-samples-data/gen-app-builder/search/alphabet-investor-pdfs` Cloud Storage folder. Make sure to use the `Cloud Storage (without metadata)` option."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Set credentials to access Enterprise Search API\n",
|
||||
"\n",
|
||||
"The [Gen App Builder client libraries](https://cloud.google.com/generative-ai-app-builder/docs/libraries) used by the Enterprise Search retriever provide high-level language support for authenticating to Gen App Builder programmatically. Client libraries support [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials); the libraries look for credentials in a set of defined locations and use those credentials to authenticate requests to the API. With ADC, you can make credentials available to your application in a variety of environments, such as local development or production, without needing to modify your application code.\n",
|
||||
"\n",
|
||||
"If running in [Google Colab](https://colab.google), authenticate with `google.colab.auth`; otherwise, follow one of the [supported methods](https://cloud.google.com/docs/authentication/application-default-credentials) to make sure that your Application Default Credentials are properly set."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"if \"google.colab\" in sys.modules:\n",
|
||||
" from google.colab import auth as google_auth\n",
|
||||
"\n",
|
||||
" google_auth.authenticate_user()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configure and use the Enterprise Search retriever\n",
|
||||
"\n",
|
||||
"The Enterprise Search retriever is implemented in the `langchain.retrievers.GoogleCloudEnterpriseSearchRetriever` class. The `get_relevant_documents` method returns a list of `langchain.schema.Document` documents where the `page_content` field of each document is populated with the document content.\n",
|
||||
"Depending on the data type used in Enterprise Search (structured or unstructured), the `page_content` field is populated as follows:\n",
|
||||
"- Structured data source: either an `extractive segment` or an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
|
||||
"- Unstructured data source: a string json containing all the fields returned from the structured data source. The `metadata` field is populated with metadata (if any) of the document \n",
|
||||
"\n",
|
||||
"### Only for Unstructured data sources:\n",
|
||||
"An extractive answer is verbatim text that is returned with each search result. It is extracted directly from the original document. Extractive answers are typically displayed near the top of web pages to provide an end user with a brief answer that is contextually relevant to their query. Extractive answers are available for website and unstructured search.\n",
|
||||
"\n",
|
||||
"An extractive segment is verbatim text that is returned with each search result. An extractive segment is usually more verbose than an extractive answer. Extractive segments can be displayed as an answer to a query, and can be used to perform post-processing tasks and as input for large language models to generate answers or new text. Extractive segments are available for unstructured search.\n",
|
||||
"\n",
|
||||
"For more information about extractive segments and extractive answers refer to [product documentation](https://cloud.google.com/generative-ai-app-builder/docs/snippets).\n",
|
||||
"\n",
|
||||
"When creating an instance of the retriever you can specify a number of parameters that control which Enterprise data store to access and how a natural language query is processed, including configurations for extractive answers and segments.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### The mandatory parameters are:\n",
|
||||
"\n",
|
||||
"- `project_id` - Your Google Cloud PROJECT_ID\n",
|
||||
"- `search_engine_id` - The ID of the data store you want to use. \n",
|
||||
"\n",
|
||||
"The `project_id` and `search_engine_id` parameters can be provided explicitly in the retriever's constructor or through the environment variables - `PROJECT_ID` and `SEARCH_ENGINE_ID`.\n",
|
||||
"\n",
|
||||
"You can also configure a number of optional parameters, including:\n",
|
||||
"\n",
|
||||
"- `max_documents` - The maximum number of documents used to provide extractive segments or extractive answers\n",
|
||||
"- `get_extractive_answers` - By default, the retriever is configured to return extractive segments. Set this field to `True` to return extractive answers. This is used only when `engine_data_type` is set to 0 (unstructured) \n",
|
||||
"- `max_extractive_answer_count` - The maximum number of extractive answers returned in each search result.\n",
|
||||
"    At most 5 answers will be returned. This is used only when `engine_data_type` is set to 0 (unstructured) \n",
|
||||
"- `max_extractive_segment_count` - The maximum number of extractive segments returned in each search result.\n",
|
||||
"    Currently one segment will be returned. This is used only when `engine_data_type` is set to 0 (unstructured) \n",
|
||||
"- `filter` - The filter expression that allows you to filter the search results based on the metadata associated with the documents in the searched data store. \n",
|
||||
"- `query_expansion_condition` - Specification to determine under which conditions query expansion should occur.\n",
|
||||
" 0 - Unspecified query expansion condition. In this case, server behavior defaults to disabled.\n",
|
||||
" 1 - Disabled query expansion. Only the exact search query is used, even if SearchResponse.total_size is zero.\n",
|
||||
" 2 - Automatic query expansion built by the Search API.\n",
|
||||
"- `engine_data_type` - Defines the enterprise search data type\n",
|
||||
" 0 - Unstructured data \n",
|
||||
" 1 - Structured data\n",
|
||||
"\n"
|
||||
]
|
||||
},
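As a sketch of how those optional parameters combine (the parameter names come from the list above; the values below are placeholders, not recommendations):

```python
# Hypothetical configuration exercising the optional parameters described above.
retriever = GoogleCloudEnterpriseSearchRetriever(
    project_id=PROJECT_ID,
    search_engine_id=SEARCH_ENGINE_ID,
    max_documents=5,
    max_extractive_segment_count=1,              # only applies when engine_data_type=0
    filter='category: ANY("investor_reports")',  # placeholder metadata filter expression
    query_expansion_condition=2,                 # 2 = automatic query expansion
    engine_data_type=0,                          # 0 = unstructured data
)
```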
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Configure and use the retriever for **unstructured** data with extractive segments"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.retrievers import GoogleCloudEnterpriseSearchRetriever\n",
|
||||
"\n",
|
||||
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
|
||||
"SEARCH_ENGINE_ID = \"<YOUR SEARCH ENGINE ID>\" # Set to your data store ID"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = GoogleCloudEnterpriseSearchRetriever(\n",
|
||||
" project_id=PROJECT_ID,\n",
|
||||
" search_engine_id=SEARCH_ENGINE_ID,\n",
|
||||
" max_documents=3,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What are Alphabet's Other Bets?\"\n",
|
||||
"\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"for doc in result:\n",
|
||||
" print(doc)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Configure and use the retriever for **unstructured** data with extractive answers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = GoogleCloudEnterpriseSearchRetriever(\n",
|
||||
" project_id=PROJECT_ID,\n",
|
||||
" search_engine_id=SEARCH_ENGINE_ID,\n",
|
||||
" max_documents=3,\n",
|
||||
" max_extractive_answer_count=3,\n",
|
||||
" get_extractive_answers=True,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What are Alphabet's Other Bets?\"\n",
|
||||
"\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"for doc in result:\n",
|
||||
" print(doc)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Configure and use the retriever for **structured** data with extractive answers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = GoogleCloudEnterpriseSearchRetriever(\n",
|
||||
" project_id=PROJECT_ID,\n",
|
||||
" search_engine_id=SEARCH_ENGINE_ID,\n",
|
||||
" max_documents=3,\n",
|
||||
" engine_data_type=1\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"for doc in result:\n",
|
||||
" print(doc)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -30,7 +30,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"! pip install google-cloud-discoveryengine"
|
||||
"! pip install google-cloud-discoveryengine\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -80,7 +80,7 @@
|
||||
"if \"google.colab\" in sys.modules:\n",
|
||||
" from google.colab import auth as google_auth\n",
|
||||
"\n",
|
||||
" google_auth.authenticate_user()"
|
||||
" google_auth.authenticate_user()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -90,12 +90,13 @@
|
||||
"## Configure and use the Vertex AI Search retriever\n",
|
||||
"\n",
|
||||
"The Vertex AI Search retriever is implemented in the `langchain.retrievers.GoogleVertexAISearchRetriever` class. The `get_relevant_documents` method returns a list of `langchain.schema.Document` documents where the `page_content` field of each document is populated with the document content.\n",
|
||||
"Depending on the data type used in Vertex AI Search (structured or unstructured) the `page_content` field is populated as follows:\n",
|
||||
"Depending on the data type used in Vertex AI Search (website, structured or unstructured) the `page_content` field is populated as follows:\n",
|
||||
"\n",
|
||||
"- Structured data source: either an `extractive segment` or an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
|
||||
"- Unstructured data source: a string json containing all the fields returned from the structured data source. The `metadata` field is populated with metadata (if any) of the document\n",
|
||||
"- Website with advanced indexing: an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
|
||||
"- Unstructured data source: either an `extractive segment` or an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
|
||||
"- Structured data source: a string json containing all the fields returned from the structured data source. The `metadata` field is populated with metadata (if any) of the document\n",
|
||||
"\n",
|
||||
"### Only for Unstructured data sources:\n",
|
||||
"### Extractive answers & extractive segments\n",
|
||||
"\n",
|
||||
"An extractive answer is verbatim text that is returned with each search result. It is extracted directly from the original document. Extractive answers are typically displayed near the top of web pages to provide an end user with a brief answer that is contextually relevant to their query. Extractive answers are available for website and unstructured search.\n",
|
||||
"\n",
|
||||
@@ -136,6 +137,7 @@
|
||||
"- `engine_data_type` - Defines the Vertex AI Search data type\n",
|
||||
" - `0` - Unstructured data\n",
|
||||
" - `1` - Structured data\n",
|
||||
" - `2` - Website data with [Advanced Website Indexing](https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing)\n",
|
||||
"\n",
|
||||
"### Migration guide for `GoogleCloudEnterpriseSearchRetriever`\n",
|
||||
"\n",
|
||||
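The body of that migration guide is outside this hunk, but the gist can be read off the old and new constructors that appear elsewhere in this diff: the data store is now addressed by `data_store_id` plus a `location_id` instead of `search_engine_id`.

```python
# Before (deprecated): Enterprise Search retriever keyed by search_engine_id.
retriever = GoogleCloudEnterpriseSearchRetriever(
    project_id=PROJECT_ID,
    search_engine_id=SEARCH_ENGINE_ID,
    max_documents=3,
)

# After: Vertex AI Search retriever keyed by data_store_id and its location.
retriever = GoogleVertexAISearchRetriever(
    project_id=PROJECT_ID,
    location_id=LOCATION_ID,      # e.g. "global"
    data_store_id=DATA_STORE_ID,
    max_documents=3,
)
```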
@@ -165,7 +167,7 @@
|
||||
"\n",
|
||||
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
|
||||
"LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n",
|
||||
"DATA_STORE_ID = \"<YOUR DATA STORE ID>\" # Set to your data store ID"
|
||||
"DATA_STORE_ID = \"<YOUR DATA STORE ID>\" # Set to your data store ID\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -179,7 +181,7 @@
|
||||
" location_id=LOCATION_ID,\n",
|
||||
" data_store_id=DATA_STORE_ID,\n",
|
||||
" max_documents=3,\n",
|
||||
")"
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -192,7 +194,7 @@
|
||||
"\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"for doc in result:\n",
|
||||
" print(doc)"
|
||||
" print(doc)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -219,7 +221,7 @@
|
||||
"\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"for doc in result:\n",
|
||||
" print(doc)"
|
||||
" print(doc)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -245,21 +247,44 @@
|
||||
"\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"for doc in result:\n",
|
||||
" print(doc)"
|
||||
" print(doc)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Configure and use the retrieve for multi-turn search"
|
||||
"### Configure and use the retriever for **website** data with Advanced Website Indexing\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = GoogleVertexAISearchRetriever(\n",
|
||||
" project_id=PROJECT_ID,\n",
|
||||
" location_id=LOCATION_ID,\n",
|
||||
" data_store_id=DATA_STORE_ID,\n",
|
||||
" max_documents=3,\n",
|
||||
" max_extractive_answer_count=3,\n",
|
||||
" get_extractive_answers=True,\n",
|
||||
" engine_data_type=2,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"for doc in result:\n",
|
||||
" print(doc)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Search with follow-ups is [based](https://cloud.google.com/generative-ai-app-builder/docs/multi-turn-search) on generative AI models and it is different from the regular unstructured data search."
|
||||
"### Configure and use the retriever for multi-turn search\n",
|
||||
"\n",
|
||||
"[Search with follow-ups](https://cloud.google.com/generative-ai-app-builder/docs/multi-turn-search) is based on generative AI models and it is different from the regular unstructured data search.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
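The construction of the multi-turn retriever is not shown in this diff. As a hedged sketch, it presumably takes the same project/location/data-store identifiers; the class name below is the one LangChain exposes for search with follow-ups, but treat the exact signature as an assumption:

```python
# Assumed construction of the multi-turn (search-with-follow-ups) retriever.
from langchain.retrievers import GoogleVertexAIMultiTurnSearchRetriever

retriever = GoogleVertexAIMultiTurnSearchRetriever(
    project_id=PROJECT_ID,
    location_id=LOCATION_ID,
    data_store_id=DATA_STORE_ID,
)
```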
@@ -276,7 +301,7 @@
|
||||
"\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"for doc in result:\n",
|
||||
" print(doc)"
|
||||
" print(doc)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -64,13 +64,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 3,
|
||||
"id": "b4d4d386-2a6b-4942-863e-9202f5a9f1d6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.retrievers import KayAiRetriever\n",
|
||||
"import os\n",
|
||||
"from langchain.retrievers import KayAiRetriever\n",
|
||||
"from kay.rag.retrievers import KayRetriever\n",
|
||||
"os.environ[\"KAY_API_KEY\"] = KAY_API_KEY\n",
|
||||
"retriever = KayAiRetriever.create(dataset_id=\"company\", data_types=[\"10-K\", \"10-Q\", \"PressRelease\"], num_contexts=3)\n",
|
||||
@@ -79,19 +79,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 4,
|
||||
"id": "04ee2d6b-c2ab-4e15-8a8b-afaf6ef8c0f6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Company Name: ROKU INC\\nCompany Industry: CABLE & OTHER PAY TELEVISION SERVICES\\nArticle Title: Roku and FreeWheel Announce Strategic Partnership to Bring Roku’s Leading Ad Tech to FreeWheel Customers\\nText: Additionally, eMarketer Link: https://cts.businesswire.com/ct/CT?id=smartlink&url=https%3A%2F%2Fwww.insiderintelligence.com%2Finsights%2Favod-more-than-50-percent-of-us-digital-video-viewers%2F&esheet=53451144&newsitemid=20230712907788&lan=en-US&anchor=eMarketer&index=4&md5=b64dea72bcf6b6379474462602781d83 projects 57% of U.S. digital video users will stream an advertising-based video on demand (AVOD) service this year.\\nHaving solutions aimed at driving greater interoperability and automation will help accelerate this growth.\\nKey highlights of this collaboration include:\\nStreamlined Integration: Roku has now integrated its demand application programming interface (dAPI) with FreeWheel s TV platform. Roku s demand API gives publishers direct, automatic and real-time access to more advertiser demand. This enhanced integration allows for streamlined ad operation workflows and better inventory quality control, both of which will improve publisher yield and revenue.\\nSeamless Data Targeting: Publishers can now use Roku platform signals to enable advertisers to target audiences and measure campaign performance without relying on cookies. Additionally, FreeWheel and Roku will rely on data clean room technology to enable the activation of additional data sets providing better measurement and monetization to publishers and agencies.', metadata={'_additional': {'id': '962b79e0-f9d1-43ae-9f7a-8a9b42bc7a9a'}, 'chunk_type': 'text', 'chunk_years_mentioned': [], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': 'PressRelease', 'data_source_link': 'https://www.nasdaq.com/press-release/roku-and-freewheel-announce-strategic-partnership-to-bring-rokus-leading-ad-tech-to', 'data_source_publish_date': '2023-07-12T00:00:00Z', 'data_source_uid': 'a46f309c-705d-3946-96db-87aa4e73261f', 'title': 'ROKU INC | Roku and FreeWheel Announce Strategic Partnership to Bring Roku’s Leading Ad Tech to FreeWheel Customers'}),\n",
|
||||
" Document(page_content='Company Name: ROKU INC \\n Company Industry: CABLE & OTHER PAY TELEVISION SERVICES \\n Form Title: 10-K 2022-FY \\n Form Section: Risk Factors \\n Text: nd the Note Regarding Forward Looking Statements.This section of this Annual Report generally discusses fiscal years 2022 and 2021 and year to year comparisons between those years.Discussions of fiscal year 2020 and year to year comparisons between fiscal years 2021 and 2020 that are not included in this Annual Report can be found in Management\\'s Discussion and Analysis of Financial Condition and Results of Operations in Part II, Item 7 of our Annual Report for the fiscal year ended December 31, 2021 filed with the SEC on February 18, 2022.Overview Effective as of the fourth quarter of fiscal 2022, we reorganized our reportable segments to better align with management\\'s reporting of information reviewed by the Chief Operating Decision Maker (\"CODM\") for each segment.We renamed our \"player\" segment to \"devices\" which now includes our licensing arrangements with service operators and licensed Roku TV partners in addition to sales of our streaming players, audio products, smart home products and Roku branded TVs that will be designed, made, and sold by us in 2023.Our historical segment information is recast to conform to our new presentation in our financial statements and accompanying notes included in Item 8 of this Annual Report.Our two reportable segments are the platform segment and the devices segment.', metadata={'_additional': {'id': 'a76c5fed-5d63-45a7-b63a-2c30e05140fc'}, 'chunk_type': 'text', 'chunk_years_mentioned': [2020, 2021, 2022, 2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': '10-K', 'data_source_link': 'https://www.sec.gov/Archives/edgar/data/1428439/000142843923000007', 'data_source_publish_date': '2022-01-01T00:00:00Z', 'data_source_uid': '0001428439-23-000007', 'title': 'ROKU INC | 10-K 2022-FY '}),\n",
|
||||
" Document(page_content='Company Name: ROKU INC \\n Company Industry: CABLE & OTHER PAY TELEVISION SERVICES \\n Form Title: 10-Q 2023-Q1 \\n Form Section: Risk Factors \\n Text: Our current and potential partners include TV brands, cable and satellite companies, and telecommunication providers.Under these license arrangements, we generally have limited or no control over the amount and timing of resources these entities dedicate to the relationship.In the past, our licensed Roku TV partners have failed to meet their forecasts and anticipated market launch dates for distributing Roku TV models, and they may fail to meet their forecasts or such launches in the future.If our licensed Roku TV partners or service operator partners fail to meet their forecasts or such launches for distributing licensed streaming devices or choose to deploy competing streaming solutions within their product lines, our business may be harmed.We depend on a small number of content publishers for a majority of our streaming hours, and if we fail to maintain these relationships, our business could be harmed.*Historically, a small number of content publishers have accounted for a significant portion of the hours streamed on our platform.In the three months ended March 31, 2023, the top three streaming services represented over 50% of all hours streamed in the period.If, for any reason, we cease distributing channels that have historically streamed a large percentage of the aggregate streaming hours on our platform, our streaming hours, our active accounts, or Roku streaming device sales may be adversely affected, and our business may be harmed.', metadata={'_additional': {'id': '2a92b2bb-02a0-4e15-8b64-d7e04078a205'}, 'chunk_type': 'text', 'chunk_years_mentioned': [2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': '10-Q', 'data_source_link': 'https://www.sec.gov/Archives/edgar/data/1428439/000142843923000017', 'data_source_publish_date': '2023-01-01T00:00:00Z', 'data_source_uid': '0001428439-23-000017', 'title': 'ROKU INC | 10-Q 2023-Q1 '})]"
|
||||
"[Document(page_content='Company Name: ROKU INC\\nCompany Industry: CABLE & OTHER PAY TELEVISION SERVICES\\nArticle Title: Roku Is One of Fast Company\\'s Most Innovative Companies for 2023\\nText: The company launched several new devices, including the Roku Voice Remote Pro; upgraded its most premium player, the Roku Ultra; and expanded its products with a new line of smart home devices such as video doorbells, lights, and plugs integrated into the Roku ecosystem. Recently, the company announced it will launch Roku-branded TVs this spring to offer more choice and innovation to both consumers and Roku TV partners. Throughout 2022, Roku also updated its operating system (OS), the only OS purpose-built for TV, with more personalization features and enhancements across search, audio, and content discovery, launching The Buzz, Sports, and What to Watch, which provides tailored movie and TV recommendations on the Home Screen Menu. The company also released a new feature for streamers, Photo Streams, that allows customers to display and share photo albums through Roku streaming devices. Additionally, Roku unveiled Shoppable Ads, a new ad innovation that makes shopping on TV streaming as easy as it is on social media. Viewers simply press \"OK\" with their Roku remote on a shoppable ad and proceed to check out with their shipping and payment details pre-populated from Roku Pay, its proprietary payments platform. Walmart was the exclusive retailer for the launch, a first-of-its-kind partnership.', metadata={'chunk_type': 'text', 'chunk_years_mentioned': [2022, 2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': 'PressRelease', 'data_source_link': 'https://newsroom.roku.com/press-releases', 'data_source_publish_date': '2023-03-02T09:30:00-04:00', 'data_source_uid': '963d4a81-f58e-3093-af68-987fb1758c15', 'title': \"ROKU INC | Roku Is One of Fast Company's Most Innovative Companies for 2023\"}),\n",
|
||||
" Document(page_content='Company Name: ROKU INC\\nCompany Industry: CABLE & OTHER PAY TELEVISION SERVICES\\nArticle Title: Roku Is One of Fast Company\\'s Most Innovative Companies for 2023\\nText: Finally, Roku grew its content offering with thousands of apps and watching options for users, including content on The Roku Channel, a top five app by reach and engagement on the Roku platform in the U.S. in 2022. In November, Roku released its first feature film, \"WEIRD: The Weird Al\\' Yankovic Story,\" a biopic starring Daniel Radcliffe. Throughout the year, The Roku Channel added FAST channels from NBCUniversal and the National Hockey League, as well as an exclusive AMC channel featuring its signature drama \"Mad Men.\" This year, the company announced a deal with Warner Bros. Discovery, launching new channels that will include \"Westworld\" and \"The Bachelor,\" in addition to 2,000 hours of on-demand content. Read more about Roku\\'s journey here . Fast Company\\'s Most Innovative Companies issue (March/April 2023) is available online here , as well as in-app via iTunes and on newsstands beginning March 14. About Roku, Inc.\\nRoku pioneered streaming to the TV. We connect users to the streaming content they love, enable content publishers to build and monetize large audiences, and provide advertisers with unique capabilities to engage consumers. Roku streaming players and TV-related audio devices are available in the U.S. and in select countries through direct retail sales and licensing arrangements with service operators. Roku TV models are available in the U.S. and select countries through licensing arrangements with TV OEM brands.', metadata={'chunk_type': 'text', 'chunk_years_mentioned': [2022, 2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': 'PressRelease', 'data_source_link': 'https://newsroom.roku.com/press-releases', 'data_source_publish_date': '2023-03-02T09:30:00-04:00', 'data_source_uid': '963d4a81-f58e-3093-af68-987fb1758c15', 'title': \"ROKU INC | Roku Is One of Fast Company's Most Innovative Companies for 2023\"}),\n",
|
||||
" Document(page_content='Company Name: ROKU INC\\nCompany Industry: CABLE & OTHER PAY TELEVISION SERVICES\\nArticle Title: Roku\\'s New NFL Zone Gives Fans Easy Access to NFL Games Right On Time for 2023 Season\\nText: In partnership with the NFL, the new NFL Zone offers viewers an easy way to find where to watch NFL live games Today, Roku (NASDAQ: ROKU ) and the National Football League (NFL) announced the recently launched NFL Zone within the Roku Sports experience to kick off the 2023 NFL season. This strategic partnership between Roku and the NFL marks the first official league-branded zone within Roku\\'s Sports experience. Available now, the NFL Zone offers football fans a centralized location to find live and upcoming games, so they can spend less time figuring out where to watch the game and more time rooting for their favorite teams. Users can also tune in for weekly game previews, League highlights, and additional NFL content, all within the zone. This press release features multimedia. View the full release here: In partnership with the NFL, Roku\\'s new NFL Zone offers viewers an easy way to find where to watch NFL live games (Photo: Business Wire) \"Last year we introduced the Sports experience for our highly engaged sports audience, making it simpler for Roku users to watch sports programming,\" said Gidon Katz, President, Consumer Experience, at Roku. \"As we start the biggest sports season of the year, providing easy access to NFL games and content to our millions of users is a top priority for us. We look forward to fans immersing themselves within the NFL Zone and making it their destination to find NFL games.', metadata={'chunk_type': 'text', 'chunk_years_mentioned': [2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': 'PressRelease', 'data_source_link': 'https://newsroom.roku.com/press-releases', 'data_source_publish_date': '2023-09-12T09:00:00-04:00', 'data_source_uid': '963d4a81-f58e-3093-af68-987fb1758c15', 'title': \"ROKU INC | Roku's New NFL Zone Gives Fans Easy Access to NFL Games Right On Time for 2023 Season\"})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
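Each `Document` returned above carries Kay.ai metadata such as `data_source`, `data_source_publish_date`, and `title`. A small illustrative aside (not in the diff), assuming the list returned by the retriever is bound to a name like `docs`:

```python
# Quick scan of the source, date, and title of each retrieved document.
for doc in docs:
    meta = doc.metadata
    print(meta["data_source"], meta["data_source_publish_date"], "-", meta["title"])
```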
|
||||
@@ -28,19 +28,29 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 11,
|
||||
"id": "63a8af5b",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[33mWARNING: You are using pip version 22.0.4; however, version 23.3 is available.\n",
|
||||
"You should consider upgrading via the '/Users/joe/projects/elastic/langchain/libs/langchain/.venv/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
|
||||
"\u001b[0m"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#!pip install lark elasticsearch"
|
||||
"#!pip install -qU lark elasticsearch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "cb4a5787",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -60,7 +70,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 2,
|
||||
"id": "bcbe04d9",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -115,7 +125,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 3,
|
||||
"id": "86e34dbf",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -164,17 +174,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 4,
|
||||
"id": "38a126e9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='dinosaur' filter=None limit=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
@@ -184,7 +187,7 @@
|
||||
" Document(page_content='A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea', metadata={'year': 2006, 'director': 'Satoshi Kon', 'rating': 8.6})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -196,24 +199,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"id": "b19d4da0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig') limit=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='A bunch of normal-sized women are supremely wholesome and some men pine after them', metadata={'year': 2019, 'director': 'Greta Gerwig', 'rating': 8.3})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -237,7 +233,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 6,
|
||||
"id": "bff36b88-b506-4877-9c63-e5a1a8d78e64",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@@ -256,19 +252,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 7,
|
||||
"id": "2758d229-4f97-499c-819f-888acaf8ee10",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='dinosaur' filter=None limit=2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
@@ -276,7 +265,7 @@
|
||||
" Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -297,24 +286,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 8,
|
||||
"id": "e460da93",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='animated toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Operation(operator=<Operator.OR: 'or'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated'), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='comedy')]), Comparison(comparator=<Comparator.GTE: 'gte'>, attribute='year', value=1990)]) limit=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -325,21 +307,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": null,
|
||||
"id": "0851fc42",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"ObjectApiResponse({'acknowledged': True})"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vectorstore.client.indices.delete(index=\"elasticsearch-self-query-demo\")"
|
||||
]
|
||||
@@ -361,7 +332,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.10.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
120  docs/docs/integrations/retrievers/singlestoredb.ipynb  Normal file
@@ -0,0 +1,120 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ab66dd43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SingleStoreDB\n",
|
||||
"\n",
|
||||
">[SingleStoreDB](https://singlestore.com/) is a high-performance distributed SQL database that supports deployment both in the [cloud](https://www.singlestore.com/cloud/) and on-premises. It provides vector storage, and vector functions including [dot_product](https://docs.singlestore.com/managed-service/en/reference/sql-reference/vector-functions/dot_product.html) and [euclidean_distance](https://docs.singlestore.com/managed-service/en/reference/sql-reference/vector-functions/euclidean_distance.html), thereby supporting AI applications that require text similarity matching. \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"This notebook shows how to use a retriever that uses `SingleStoreDB`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "51b49135-a61a-49e8-869d-7c1d76794cd7",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Establishing a connection to the database is facilitated through the singlestoredb Python connector.\n",
|
||||
"# Please ensure that this connector is installed in your working environment.\n",
|
||||
"!pip install singlestoredb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aaf80e7f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create Retriever from vector store"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bcb3c8c2",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"\n",
|
||||
"# We want to use OpenAIEmbeddings so we have to get the OpenAI API Key.\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
|
||||
"\n",
|
||||
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import SingleStoreDB\n",
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"\n",
|
||||
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
|
||||
"documents = loader.load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
"docs = text_splitter.split_documents(documents)\n",
|
||||
"\n",
|
||||
"embeddings = OpenAIEmbeddings()\n",
|
||||
"\n",
|
||||
"# Setup connection url as environment variable\n",
|
||||
"os.environ[\"SINGLESTOREDB_URL\"] = \"root:pass@localhost:3306/db\"\n",
|
||||
"\n",
|
||||
"# Load documents to the store\n",
|
||||
"docsearch = SingleStoreDB.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" table_name=\"notebook\", # use table with a custom name\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# create retriever from the vector store\n",
|
||||
"retriever = docsearch.as_retriever(search_kwargs={\"k\": 2})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fc0915db",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Search with retriever"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "b605284d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = retriever.get_relevant_documents(\"What did the president say about Ketanji Brown Jackson\")\n",
|
||||
"print(docs[0].page_content)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
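The markdown cell above introduces the SingleStoreDB retriever; the sketch below shows one way the retriever created in that notebook might be plugged into a question-answering chain. It is illustrative only and not part of this diff: `docsearch` is the vector store built in the notebook, and `RetrievalQA`/`ChatOpenAI` are standard LangChain components assumed to be available.

```python
# Illustrative sketch, not part of this diff: assumes the `docsearch` SingleStoreDB
# vector store from the notebook above and a valid OPENAI_API_KEY / SINGLESTOREDB_URL.
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI

retriever = docsearch.as_retriever(search_kwargs={"k": 2})

# "stuff" puts the k retrieved chunks directly into the prompt for the chat model.
qa = RetrievalQA.from_chain_type(
    llm=ChatOpenAI(temperature=0),
    chain_type="stuff",
    retriever=retriever,
)
print(qa.run("What did the president say about Ketanji Brown Jackson?"))
```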
@@ -139,7 +139,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 1,
|
||||
"id": "67ab8afa-f7c6-4fbf-b596-cb512da949da",
|
||||
"metadata": {
|
||||
"id": "67ab8afa-f7c6-4fbf-b596-cb512da949da",
|
||||
@@ -172,7 +172,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"id": "aac9563e",
|
||||
"metadata": {
|
||||
"id": "aac9563e",
|
||||
@@ -186,7 +186,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"id": "a3c3999a",
|
||||
"metadata": {
|
||||
"id": "a3c3999a",
|
||||
@@ -207,7 +207,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 4,
|
||||
"id": "12eb86d8",
|
||||
"metadata": {
|
||||
"id": "12eb86d8",
|
||||
@@ -218,7 +218,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[Document(page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.', metadata={'source': '../../modules/state_of_the_union.txt'}), Document(page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.', metadata={'source': '../../modules/state_of_the_union.txt', 'date': '2016-01-01', 'rating': 2, 'author': 'John Doe'}), Document(page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.', metadata={'source': '../../modules/state_of_the_union.txt', 'date': '2010-01-01', 'rating': 1, 'author': 'John Doe'}), Document(page_content='As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. \\n\\nWhile it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice.', metadata={'source': '../../modules/state_of_the_union.txt'})]\n"
|
||||
"[Document(page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.', metadata={'source': '../../modules/state_of_the_union.txt'}), Document(page_content='As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. \\n\\nWhile it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice.', metadata={'source': '../../modules/state_of_the_union.txt'}), Document(page_content='A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \\n\\nAnd if we are to advance liberty and justice, we need to secure the Border and fix the immigration system.', metadata={'source': '../../modules/state_of_the_union.txt'}), Document(page_content='This is personal to me and Jill, to Kamala, and to so many of you. \\n\\nCancer is the #2 cause of death in America–second only to heart disease. \\n\\nLast month, I announced our plan to supercharge \\nthe Cancer Moonshot that President Obama asked me to lead six years ago. \\n\\nOur goal is to cut the cancer death rate by at least 50% over the next 25 years, turn more cancers from death sentences into treatable diseases. \\n\\nMore support for patients and families.', metadata={'source': '../../modules/state_of_the_union.txt'})]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -247,7 +247,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 5,
|
||||
"id": "5d076412",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -284,12 +284,13 @@
|
||||
"## Filtering Metadata\n",
|
||||
"With metadata added to the documents, you can add metadata filtering at query time. \n",
|
||||
"\n",
|
||||
"### Example: Filter by keyword"
|
||||
"### Example: Filter by Exact keyword\n",
|
||||
"Notice: We are using the keyword subfield thats not analyzed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"execution_count": 6,
|
||||
"id": "b2a4bd1b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -297,12 +298,42 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'source': '../../modules/state_of_the_union.txt', 'date': '2010-01-01', 'rating': 1, 'author': 'John Doe', 'geo_location': {'lat': 40.12, 'lon': -71.34}}\n"
|
||||
"{'source': '../../modules/state_of_the_union.txt', 'date': '2016-01-01', 'rating': 2, 'author': 'John Doe'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"docs = db.similarity_search(query, filter=[{ \"match\": { \"metadata.author\": \"John Doe\"}}])\n",
|
||||
"docs = db.similarity_search(query, filter=[{ \"term\": { \"metadata.author.keyword\": \"John Doe\"}}])\n",
|
||||
"print(docs[0].metadata)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1898ab77",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Example: Filter by Partial Match\n",
|
||||
"This example shows how to filter by partial match. This is useful when you don't know the exact value of the metadata field. For example, if you want to filter by the metadata field `author` and you don't know the exact value of the author, you can use a partial match to filter by the author's last name. Fuzzy matching is also supported.\n",
|
||||
"\n",
|
||||
"\"Jon\" matches on \"John Doe\" as \"Jon\" is a close match to \"John\" token."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "f3d294ff",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'source': '../../modules/state_of_the_union.txt', 'date': '2016-01-01', 'rating': 2, 'author': 'John Doe'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"docs = db.similarity_search(query, filter=[{ \"match\": { \"metadata.author\": { \"query\": \"Jon\", \"fuzziness\": \"AUTO\" } }}])\n",
|
||||
"print(docs[0].metadata)"
|
||||
]
|
||||
},
|
||||
|
||||
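The two filtering cells above show a `term` filter on the keyword subfield and a fuzzy `match` filter separately; the sketch below is a hypothetical combination of the same ideas in a single `bool` clause, using the metadata fields from the example documents. It is not part of this diff, and the exact clause shape may need adjusting for a given index mapping.

```python
# Hypothetical sketch combining an exact keyword filter with a numeric range filter;
# `db`, `query`, and the metadata fields come from the cells above in this notebook.
docs = db.similarity_search(
    query,
    filter=[
        {
            "bool": {
                "must": [
                    # exact match against the not-analyzed keyword subfield
                    {"term": {"metadata.author.keyword": "John Doe"}},
                    # keep only documents rated 2 or higher
                    {"range": {"metadata.rating": {"gte": 2}}},
                ]
            }
        }
    ],
)
print(docs[0].metadata)
```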
@@ -4,9 +4,8 @@ This output parser can be used when you want to return a list of comma-separated
|
||||
|
||||
```python
|
||||
from langchain.output_parsers import CommaSeparatedListOutputParser
|
||||
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
output_parser = CommaSeparatedListOutputParser()
|
||||
|
||||
|
||||
BIN  docs/static/img/qa_privacy_protection.png  vendored  (binary file not shown; size 150 KiB before, 185 KiB after)
@@ -3848,8 +3848,8 @@
|
||||
"destination": "/docs/additional_resources/dependents"
|
||||
},
|
||||
{
|
||||
"source": "docs/integrations/retrievers/google_cloud_enterprise_search",
|
||||
"destination": "docs/integrations/retrievers/google_vertex_ai_search"
|
||||
"source": "/docs/integrations/retrievers/google_cloud_enterprise_search",
|
||||
"destination": "/docs/integrations/retrievers/google_vertex_ai_search"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Union
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from packaging.version import parse
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema.agent import AgentAction, AgentFinish
|
||||
@@ -142,7 +144,7 @@ def _get_user_props(metadata: Any) -> Any:
|
||||
return user_props_ctx.get()
|
||||
|
||||
metadata = metadata or {}
|
||||
return metadata.get("user_props")
|
||||
return metadata.get("user_props", None)
|
||||
|
||||
|
||||
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
|
||||
@@ -191,6 +193,8 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
__api_url: str
|
||||
__app_id: str
|
||||
__verbose: bool
|
||||
__llmonitor_version: str
|
||||
__has_valid_config: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -200,37 +204,58 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
|
||||
self.__has_valid_config = True
|
||||
|
||||
try:
|
||||
import llmonitor
|
||||
|
||||
self.__llmonitor_version = importlib.metadata.version("llmonitor")
|
||||
self.__track_event = llmonitor.track_event
|
||||
|
||||
except ImportError:
|
||||
warnings.warn(
|
||||
"""[LLMonitor] To use the LLMonitor callback handler you need to
|
||||
have the `llmonitor` Python package installed. Please install it
|
||||
with `pip install llmonitor`"""
|
||||
)
|
||||
self.__has_valid_config = False
|
||||
|
||||
if parse(self.__llmonitor_version) < parse("0.0.20"):
|
||||
warnings.warn(
|
||||
f"""[LLMonitor] The installed `llmonitor` version is
|
||||
{self.__llmonitor_version} but `LLMonitorCallbackHandler` requires
|
||||
at least version 0.0.20. Upgrade `llmonitor` with `pip install
|
||||
--upgrade llmonitor`"""
|
||||
)
|
||||
self.__has_valid_config = False
|
||||
|
||||
self.__has_valid_config = True
|
||||
|
||||
self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
|
||||
self.__verbose = verbose or bool(os.getenv("LLMONITOR_VERBOSE"))
|
||||
|
||||
_app_id = app_id or os.getenv("LLMONITOR_APP_ID")
|
||||
if _app_id is None:
|
||||
raise ValueError(
|
||||
"""app_id must be provided either as an argument or as
|
||||
warnings.warn(
|
||||
"""[LLMonitor] app_id must be provided either as an argument or as
|
||||
an environment variable"""
|
||||
)
|
||||
self.__app_id = _app_id
|
||||
self.__has_valid_config = False
|
||||
else:
|
||||
self.__app_id = _app_id
|
||||
|
||||
if self.__has_valid_config is False:
|
||||
return None
|
||||
|
||||
try:
|
||||
res = requests.get(f"{self.__api_url}/api/app/{self.__app_id}")
|
||||
if not res.ok:
|
||||
raise ConnectionError()
|
||||
except Exception as e:
|
||||
raise ConnectionError(
|
||||
f"Could not connect to the LLMonitor API at {self.__api_url}"
|
||||
) from e
|
||||
|
||||
def __send_event(self, event: Dict[str, Any]) -> None:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
event = {**event, "app": self.__app_id, "timestamp": str(datetime.utcnow())}
|
||||
|
||||
if self.__verbose:
|
||||
print("llmonitor_callback", event)
|
||||
|
||||
data = {"events": event}
|
||||
requests.post(headers=headers, url=f"{self.__api_url}/api/report", json=data)
|
||||
except Exception:
|
||||
warnings.warn(
|
||||
f"""[LLMonitor] Could not connect to the LLMonitor API at
|
||||
{self.__api_url}"""
|
||||
)
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
@@ -243,27 +268,28 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
metadata: Union[Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
user_id = _get_user_id(metadata)
|
||||
user_props = _get_user_props(metadata)
|
||||
name = kwargs.get("invocation_params", {}).get("model_name")
|
||||
input = _parse_input(prompts)
|
||||
|
||||
event = {
|
||||
"event": "start",
|
||||
"type": "llm",
|
||||
"userId": user_id,
|
||||
"runId": str(run_id),
|
||||
"parentRunId": str(parent_run_id) if parent_run_id else None,
|
||||
"input": _parse_input(prompts),
|
||||
"name": kwargs.get("invocation_params", {}).get("model_name"),
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
}
|
||||
if user_props:
|
||||
event["userProps"] = user_props
|
||||
|
||||
self.__send_event(event)
|
||||
self.__track_event(
|
||||
"llm",
|
||||
"start",
|
||||
user_id=user_id,
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
user_props=user_props,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_llm_start: {e}")
|
||||
warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}")
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
@@ -276,28 +302,29 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
metadata: Union[Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
user_id = _get_user_id(metadata)
|
||||
user_props = _get_user_props(metadata)
|
||||
name = kwargs.get("invocation_params", {}).get("model_name")
|
||||
input = _parse_lc_messages(messages[0])
|
||||
|
||||
event = {
|
||||
"event": "start",
|
||||
"type": "llm",
|
||||
"userId": user_id,
|
||||
"runId": str(run_id),
|
||||
"parentRunId": str(parent_run_id) if parent_run_id else None,
|
||||
"input": _parse_lc_messages(messages[0]),
|
||||
"name": kwargs.get("invocation_params", {}).get("model_name"),
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
}
|
||||
if user_props:
|
||||
event["userProps"] = user_props
|
||||
|
||||
self.__send_event(event)
|
||||
self.__track_event(
|
||||
"llm",
|
||||
"start",
|
||||
user_id=user_id,
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
user_props=user_props,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"[LLMonitor] An error occurred in on_chat_model_start: " f"{e}"
|
||||
f"[LLMonitor] An error occurred in on_chat_model_start: {e}"
|
||||
)
|
||||
|
||||
def on_llm_end(
|
||||
@@ -308,9 +335,11 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
|
||||
try:
|
||||
token_usage = (response.llm_output or {}).get("token_usage", {})
|
||||
|
||||
parsed_output = [
|
||||
{
|
||||
"text": generation.text,
|
||||
@@ -330,20 +359,19 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
for generation in response.generations[0]
|
||||
]
|
||||
|
||||
event = {
|
||||
"event": "end",
|
||||
"type": "llm",
|
||||
"runId": str(run_id),
|
||||
"parent_run_id": str(parent_run_id) if parent_run_id else None,
|
||||
"output": parsed_output,
|
||||
"tokensUsage": {
|
||||
self.__track_event(
|
||||
"llm",
|
||||
"end",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
output=parsed_output,
|
||||
token_usage={
|
||||
"prompt": token_usage.get("prompt_tokens"),
|
||||
"completion": token_usage.get("completion_tokens"),
|
||||
},
|
||||
}
|
||||
self.__send_event(event)
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_llm_end: {e}")
|
||||
warnings.warn(f"[LLMonitor] An error occurred in on_llm_end: {e}")
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
@@ -356,27 +384,27 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
metadata: Union[Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
user_id = _get_user_id(metadata)
|
||||
user_props = _get_user_props(metadata)
|
||||
name = serialized.get("name")
|
||||
|
||||
event = {
|
||||
"event": "start",
|
||||
"type": "tool",
|
||||
"userId": user_id,
|
||||
"runId": str(run_id),
|
||||
"parentRunId": str(parent_run_id) if parent_run_id else None,
|
||||
"name": serialized.get("name"),
|
||||
"input": input_str,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
}
|
||||
if user_props:
|
||||
event["userProps"] = user_props
|
||||
|
||||
self.__send_event(event)
|
||||
self.__track_event(
|
||||
"tool",
|
||||
"start",
|
||||
user_id=user_id,
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input_str,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
user_props=user_props,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_tool_start: {e}")
|
||||
warnings.warn(f"[LLMonitor] An error occurred in on_tool_start: {e}")
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
@@ -387,17 +415,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
tags: Union[List[str], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
event = {
|
||||
"event": "end",
|
||||
"type": "tool",
|
||||
"runId": str(run_id),
|
||||
"parent_run_id": str(parent_run_id) if parent_run_id else None,
|
||||
"output": output,
|
||||
}
|
||||
self.__send_event(event)
|
||||
self.__track_event(
|
||||
"tool",
|
||||
"end",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
output=output,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_tool_end: {e}")
|
||||
warnings.warn(f"[LLMonitor] An error occurred in on_tool_end: {e}")
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
@@ -410,6 +439,8 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
metadata: Union[Dict[str, Any], None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
name = serialized.get("id", [None, None, None, None])[3]
|
||||
type = "chain"
|
||||
@@ -419,35 +450,32 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
if agentName is None:
|
||||
agentName = metadata.get("agentName")
|
||||
|
||||
if name == "AgentExecutor" or name == "PlanAndExecute":
|
||||
type = "agent"
|
||||
if agentName is not None:
|
||||
type = "agent"
|
||||
name = agentName
|
||||
if name == "AgentExecutor" or name == "PlanAndExecute":
|
||||
type = "agent"
|
||||
|
||||
if parent_run_id is not None:
|
||||
type = "chain"
|
||||
|
||||
user_id = _get_user_id(metadata)
|
||||
user_props = _get_user_props(metadata)
|
||||
input = _parse_input(inputs)
|
||||
|
||||
event = {
|
||||
"event": "start",
|
||||
"type": type,
|
||||
"userId": user_id,
|
||||
"runId": str(run_id),
|
||||
"parentRunId": str(parent_run_id) if parent_run_id else None,
|
||||
"input": _parse_input(inputs),
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"name": name,
|
||||
}
|
||||
if user_props:
|
||||
event["userProps"] = user_props
|
||||
|
||||
self.__send_event(event)
|
||||
self.__track_event(
|
||||
type,
|
||||
"start",
|
||||
user_id=user_id,
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
user_props=user_props,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_chain_start: {e}")
|
||||
warnings.warn(f"[LLMonitor] An error occurred in on_chain_start: {e}")
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
@@ -457,14 +485,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
event = {
|
||||
"event": "end",
|
||||
"type": "chain",
|
||||
"runId": str(run_id),
|
||||
"output": _parse_output(outputs),
|
||||
}
|
||||
self.__send_event(event)
|
||||
output = _parse_output(outputs)
|
||||
|
||||
self.__track_event(
|
||||
"chain",
|
||||
"end",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
output=output,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_chain_end: {e}")
|
||||
|
||||
@@ -476,16 +508,20 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
event = {
|
||||
"event": "start",
|
||||
"type": "tool",
|
||||
"runId": str(run_id),
|
||||
"parentRunId": str(parent_run_id) if parent_run_id else None,
|
||||
"name": action.tool,
|
||||
"input": _parse_input(action.tool_input),
|
||||
}
|
||||
self.__send_event(event)
|
||||
name = action.tool
|
||||
input = _parse_input(action.tool_input)
|
||||
|
||||
self.__track_event(
|
||||
"tool",
|
||||
"start",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
name=name,
|
||||
input=input,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_agent_action: {e}")
|
||||
|
||||
@@ -497,15 +533,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
event = {
|
||||
"event": "end",
|
||||
"type": "agent",
|
||||
"runId": str(run_id),
|
||||
"parentRunId": str(parent_run_id) if parent_run_id else None,
|
||||
"output": _parse_output(finish.return_values),
|
||||
}
|
||||
self.__send_event(event)
|
||||
output = _parse_output(finish.return_values)
|
||||
|
||||
self.__track_event(
|
||||
"agent",
|
||||
"end",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
output=output,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
|
||||
|
||||
@@ -517,15 +556,16 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
event = {
|
||||
"event": "error",
|
||||
"type": "chain",
|
||||
"runId": str(run_id),
|
||||
"parent_run_id": str(parent_run_id) if parent_run_id else None,
|
||||
"error": {"message": str(error), "stack": traceback.format_exc()},
|
||||
}
|
||||
self.__send_event(event)
|
||||
self.__track_event(
|
||||
"chain",
|
||||
"error",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
error={"message": str(error), "stack": traceback.format_exc()},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_chain_error: {e}")
|
||||
|
||||
@@ -537,15 +577,16 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
event = {
|
||||
"event": "error",
|
||||
"type": "tool",
|
||||
"runId": str(run_id),
|
||||
"parent_run_id": str(parent_run_id) if parent_run_id else None,
|
||||
"error": {"message": str(error), "stack": traceback.format_exc()},
|
||||
}
|
||||
self.__send_event(event)
|
||||
self.__track_event(
|
||||
"tool",
|
||||
"error",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
error={"message": str(error), "stack": traceback.format_exc()},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_tool_error: {e}")
|
||||
|
||||
@@ -557,15 +598,16 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
||||
parent_run_id: Union[UUID, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.__has_valid_config is False:
|
||||
return
|
||||
try:
|
||||
event = {
|
||||
"event": "error",
|
||||
"type": "llm",
|
||||
"runId": str(run_id),
|
||||
"parent_run_id": str(parent_run_id) if parent_run_id else None,
|
||||
"error": {"message": str(error), "stack": traceback.format_exc()},
|
||||
}
|
||||
self.__send_event(event)
|
||||
self.__track_event(
|
||||
"llm",
|
||||
"error",
|
||||
run_id=str(run_id),
|
||||
parent_run_id=str(parent_run_id) if parent_run_id else None,
|
||||
error={"message": str(error), "stack": traceback.format_exc()},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"[LLMonitor] An error occurred in on_llm_error: {e}")
|
||||
|
||||
|
||||
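For context on how the refactored handler above is typically wired up, a minimal usage sketch follows; the app id is a placeholder, and it assumes `pip install llmonitor` plus a working OpenAI key. The sketch is not part of this diff.

```python
# Minimal sketch, assuming a real LLMonitor app id and the `llmonitor` package installed.
from langchain.callbacks import LLMonitorCallbackHandler
from langchain.chat_models import ChatOpenAI

handler = LLMonitorCallbackHandler(app_id="your-llmonitor-app-id")

# LLM/chat/tool/chain starts, ends, and errors on this model are now forwarded
# through llmonitor.track_event instead of the old __send_event HTTP helper.
chat = ChatOpenAI(callbacks=[handler])
chat.predict("Hello, world!")
```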
@@ -22,7 +22,7 @@ from langchain.utilities.openapi import OpenAPISpec
|
||||
from langchain.utils.input import get_colored_text
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openapi_schema_pydantic import Parameter
|
||||
from openapi_pydantic import Parameter
|
||||
|
||||
|
||||
def _get_description(o: Any, prefer_short: bool) -> Optional[str]:
|
||||
|
||||
@@ -38,6 +38,7 @@ from langchain.chat_models.minimax import MiniMaxChat
|
||||
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
|
||||
from langchain.chat_models.ollama import ChatOllama
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
|
||||
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
|
||||
from langchain.chat_models.vertexai import ChatVertexAI
|
||||
from langchain.chat_models.yandex import ChatYandexGPT
|
||||
@@ -63,6 +64,7 @@ __all__ = [
|
||||
"ErnieBotChat",
|
||||
"ChatJavelinAIGateway",
|
||||
"ChatKonko",
|
||||
"PaiEasChatEndpoint",
|
||||
"QianfanChatEndpoint",
|
||||
"ChatFireworks",
|
||||
"ChatYandexGPT",
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -38,12 +37,10 @@ from langchain.schema import (
|
||||
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.schema.runnable import RunnableConfig
|
||||
@@ -79,7 +76,7 @@ async def _agenerate_from_stream(
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
|
||||
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
"""Base class for Chat models."""
|
||||
|
||||
cache: Optional[bool] = None
|
||||
@@ -116,9 +113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
@property
|
||||
def OutputType(self) -> Any:
|
||||
"""Get the output type for this runnable."""
|
||||
return Union[
|
||||
HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
|
||||
]
|
||||
return AnyMessage
|
||||
|
||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||
if isinstance(input, PromptValue):
|
||||
@@ -140,23 +135,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessageChunk:
|
||||
) -> BaseMessage:
|
||||
config = config or {}
|
||||
return cast(
|
||||
BaseMessageChunk,
|
||||
cast(
|
||||
ChatGeneration,
|
||||
self.generate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
).generations[0][0],
|
||||
).message,
|
||||
)
|
||||
ChatGeneration,
|
||||
self.generate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
).generations[0][0],
|
||||
).message
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
@@ -165,7 +157,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessageChunk:
|
||||
) -> BaseMessage:
|
||||
config = config or {}
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input)],
|
||||
@@ -176,9 +168,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
run_name=config.get("run_name"),
|
||||
**kwargs,
|
||||
)
|
||||
return cast(
|
||||
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
|
||||
)
|
||||
return cast(ChatGeneration, llm_result.generations[0][0]).message
|
||||
|
||||
def stream(
|
||||
self,
|
||||
@@ -190,7 +180,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
) -> Iterator[BaseMessageChunk]:
|
||||
if type(self)._stream == BaseChatModel._stream:
|
||||
# model doesn't implement streaming, so use default implementation
|
||||
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
yield cast(
|
||||
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
)
|
||||
else:
|
||||
config = config or {}
|
||||
messages = self._convert_input(input).to_messages()
|
||||
@@ -241,7 +233,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
||||
) -> AsyncIterator[BaseMessageChunk]:
|
||||
if type(self)._astream == BaseChatModel._astream:
|
||||
# model doesn't implement streaming, so use default implementation
|
||||
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
yield cast(
|
||||
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
)
|
||||
else:
|
||||
config = config or {}
|
||||
messages = self._convert_input(input).to_messages()
|
||||
|
||||
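In practical terms, the `BaseChatModel` change above means `invoke`/`ainvoke` are now typed as returning a `BaseMessage`, while `stream`/`astream` still yield `BaseMessageChunk`s. The small sketch below (using `ChatOpenAI` purely as a stand-in concrete model, not part of this diff) illustrates the distinction.

```python
# Sketch only: ChatOpenAI stands in for any BaseChatModel subclass.
from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import AIMessage, AIMessageChunk

chat = ChatOpenAI()

# invoke() is typed as returning a BaseMessage (an AIMessage in practice)...
message = chat.invoke("Say hi in one word.")
assert isinstance(message, AIMessage)

# ...while stream() yields BaseMessageChunk values that can be concatenated.
for chunk in chat.stream("Count to three."):
    assert isinstance(chunk, AIMessageChunk)
```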
324  libs/langchain/langchain/chat_models/pai_eas_endpoint.py  Normal file
@@ -0,0 +1,324 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
from langchain.schema import ChatGeneration, ChatResult
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PaiEasChatEndpoint(BaseChatModel):
|
||||
"""Eas LLM Service chat model API.
|
||||
|
||||
To use, you must have a deployed EAS chat LLM service on AliCloud. You can set the
environment variables ``eas_service_url`` and ``eas_service_token`` to your EAS
service URL and service token.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import PaiEasChatEndpoint
|
||||
eas_chat_endpoint = PaiEasChatEndpoint(
|
||||
eas_service_url="your_service_url",
|
||||
eas_service_token="your_service_token"
|
||||
)
|
||||
"""
|
||||
|
||||
"""PAI-EAS Service URL"""
|
||||
eas_service_url: str
|
||||
|
||||
"""PAI-EAS Service TOKEN"""
|
||||
eas_service_token: str
|
||||
|
||||
"""PAI-EAS Service Infer Params"""
|
||||
max_new_tokens: Optional[int] = 512
|
||||
temperature: Optional[float] = 0.8
|
||||
top_p: Optional[float] = 0.1
|
||||
top_k: Optional[int] = 10
|
||||
do_sample: Optional[bool] = False
|
||||
use_cache: Optional[bool] = True
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
"""Enable stream chat mode."""
|
||||
streaming: bool = False
|
||||
|
||||
"""Key/value arguments to pass to the model. Reserved for future use"""
|
||||
model_kwargs: Optional[dict] = None
|
||||
|
||||
version: Optional[str] = "2.0"
|
||||
|
||||
timeout: Optional[int] = 5000
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["eas_service_url"] = get_from_dict_or_env(
|
||||
values, "eas_service_url", "EAS_SERVICE_URL"
|
||||
)
|
||||
values["eas_service_token"] = get_from_dict_or_env(
|
||||
values, "eas_service_token", "EAS_SERVICE_TOKEN"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
"eas_service_url": self.eas_service_url,
|
||||
"eas_service_token": self.eas_service_token,
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "pai_eas_chat_endpoint"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Cohere API."""
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"stop_sequences": [],
|
||||
"do_sample": self.do_sample,
|
||||
"use_cache": self.use_cache,
|
||||
}
|
||||
|
||||
def _invocation_params(
|
||||
self, stop_sequences: Optional[List[str]], **kwargs: Any
|
||||
) -> dict:
|
||||
params = self._default_params
|
||||
if self.model_kwargs:
|
||||
params.update(self.model_kwargs)
|
||||
if self.stop_sequences is not None and stop_sequences is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop_sequences is not None:
|
||||
params["stop"] = self.stop_sequences
|
||||
else:
|
||||
params["stop"] = stop_sequences
|
||||
return {**params, **kwargs}
|
||||
|
||||
def format_request_payload(
|
||||
self, messages: List[BaseMessage], **model_kwargs: Any
|
||||
) -> dict:
|
||||
prompt: Dict[str, Any] = {}
|
||||
user_content: List[str] = []
|
||||
assistant_content: List[str] = []
|
||||
|
||||
for message in messages:
|
||||
"""Converts message to a dict according to role"""
|
||||
if isinstance(message, HumanMessage):
|
||||
user_content = user_content + [message.content]
|
||||
elif isinstance(message, AIMessage):
|
||||
assistant_content = assistant_content + [message.content]
|
||||
elif isinstance(message, SystemMessage):
|
||||
prompt["system_prompt"] = message.content
|
||||
elif isinstance(message, ChatMessage) and message.role in [
|
||||
"user",
|
||||
"assistant",
|
||||
"system",
|
||||
]:
|
||||
if message.role == "system":
|
||||
prompt["system_prompt"] = message.content
|
||||
elif message.role == "user":
|
||||
user_content = user_content + [message.content]
|
||||
elif message.role == "assistant":
|
||||
assistant_content = assistant_content + [message.content]
|
||||
else:
|
||||
supported = ",".join([role for role in ["user", "assistant", "system"]])
|
||||
raise ValueError(
|
||||
f"""Received unsupported role.
|
||||
Supported roles for the LLaMa Foundation Model: {supported}"""
|
||||
)
|
||||
prompt["prompt"] = user_content[len(user_content) - 1]
|
||||
history = [
|
||||
history_item
|
||||
for _, history_item in enumerate(zip(user_content[:-1], assistant_content))
|
||||
]
|
||||
|
||||
prompt["history"] = history
|
||||
|
||||
return {**prompt, **model_kwargs}
|
||||
|
||||
def _format_response_payload(
|
||||
self, output: bytes, stop_sequences: Optional[List[str]]
|
||||
) -> str:
|
||||
"""Formats response"""
|
||||
try:
|
||||
text = json.loads(output)["response"]
|
||||
if stop_sequences:
|
||||
text = enforce_stop_tokens(text, stop_sequences)
|
||||
return text
|
||||
except Exception as e:
|
||||
if isinstance(e, json.decoder.JSONDecodeError):
|
||||
return output.decode("utf-8")
|
||||
raise e
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
message = AIMessage(content=output_str)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
request_payload = self.format_request_payload(messages, **params)
|
||||
response_payload = self._call_eas(request_payload)
|
||||
generated_text = self._format_response_payload(response_payload, params["stop"])
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(generated_text)
|
||||
|
||||
return generated_text
|
||||
|
||||
def _call_eas(self, query_body: dict) -> Any:
|
||||
"""Generate text from the eas service."""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"{self.eas_service_token}",
|
||||
}
|
||||
|
||||
# make request
|
||||
response = requests.post(
|
||||
self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed with status code {response.status_code}"
|
||||
f" and message {response.text}"
|
||||
)
|
||||
|
||||
return response.text
|
||||
|
||||
def _call_eas_stream(self, query_body: dict) -> Any:
|
||||
"""Generate text from the eas service."""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"{self.eas_service_token}",
|
||||
}
|
||||
|
||||
# make request
|
||||
response = requests.post(
|
||||
self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed with status code {response.status_code}"
|
||||
f" and message {response.text}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _convert_chunk_to_message_message(
|
||||
self,
|
||||
chunk: str,
|
||||
) -> AIMessageChunk:
|
||||
data = json.loads(chunk.encode("utf-8"))
|
||||
return AIMessageChunk(content=data.get("response", ""))
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
request_payload = self.format_request_payload(messages, **params)
|
||||
request_payload["use_stream_chat"] = True
|
||||
|
||||
response = self._call_eas_stream(request_payload)
|
||||
for chunk in response.iter_lines(
|
||||
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||
):
|
||||
if chunk:
|
||||
content = self._convert_chunk_to_message_message(chunk)
|
||||
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in params["stop"]:
|
||||
if stop_seq in content.content:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if stop_seq_found:
|
||||
content.content = content.content[
|
||||
: content.content.index(stop_seq_found)
|
||||
]
|
||||
|
||||
# yield text, if any
|
||||
if content.content:
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(content.content)
|
||||
yield ChatGenerationChunk(message=content)
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if stream if stream is not None else self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
async for chunk in self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
generation = chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
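A short usage sketch for the new `PaiEasChatEndpoint` added above; the service URL and token are placeholders, and the message handling follows `format_request_payload` (system messages become `system_prompt`, the last human message becomes `prompt`). Not part of this diff.

```python
# Sketch only: the URL and token are placeholders for a real PAI-EAS deployment.
from langchain.chat_models import PaiEasChatEndpoint
from langchain.schema.messages import HumanMessage, SystemMessage

chat = PaiEasChatEndpoint(
    eas_service_url="https://your-eas-service-url",
    eas_service_token="your-eas-service-token",
)

messages = [
    SystemMessage(content="You are a concise assistant."),
    HumanMessage(content="Summarize what PAI-EAS is in one sentence."),
]
print(chat(messages).content)
```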
@@ -342,6 +342,12 @@ def _import_openlm() -> Any:
|
||||
return OpenLM
|
||||
|
||||
|
||||
def _import_pai_eas_endpoint() -> Any:
|
||||
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
|
||||
|
||||
return PaiEasEndpoint
|
||||
|
||||
|
||||
def _import_petals() -> Any:
|
||||
from langchain.llms.petals import Petals
|
||||
|
||||
@@ -593,6 +599,8 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_openllm()
|
||||
elif name == "OpenLM":
|
||||
return _import_openlm()
|
||||
elif name == "PaiEasEndpoint":
|
||||
return _import_pai_eas_endpoint()
|
||||
elif name == "Petals":
|
||||
return _import_petals()
|
||||
elif name == "PipelineAI":
|
||||
@@ -703,6 +711,7 @@ __all__ = [
|
||||
"OpenAIChat",
|
||||
"OpenLLM",
|
||||
"OpenLM",
|
||||
"PaiEasEndpoint",
|
||||
"Petals",
|
||||
"PipelineAI",
|
||||
"Predibase",
|
||||
@@ -780,6 +789,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
||||
"ollama": _import_ollama,
|
||||
"openai": _import_openai,
|
||||
"openlm": _import_openlm,
|
||||
"pai_eas_endpoint": _import_pai_eas_endpoint,
|
||||
"petals": _import_petals,
|
||||
"pipelineai": _import_pipelineai,
|
||||
"predibase": _import_predibase,
|
||||
|
||||
@@ -6,9 +6,7 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.llms.base import LLM, create_base_retry_decorator
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.utils.env import get_from_dict_or_env
|
||||
|
||||
|
||||
@@ -140,42 +138,6 @@ class Fireworks(LLM):
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[str]:
|
||||
prompt = self._convert_input(input).to_string()
|
||||
generation: Optional[GenerationChunk] = None
|
||||
for chunk in self._stream(prompt):
|
||||
yield chunk.text
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[str]:
|
||||
prompt = self._convert_input(input).to_string()
|
||||
generation: Optional[GenerationChunk] = None
|
||||
async for chunk in self._astream(prompt):
|
||||
yield chunk.text
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
|
||||
|
||||
def completion_with_retry(
|
||||
llm: Fireworks,
|
||||
|
||||
@@ -91,7 +91,7 @@ class Modal(LLM):
|
||||
if prompt in response.json()["prompt"]:
|
||||
response_json = response.json()
|
||||
except KeyError:
|
||||
raise ValueError("LangChain requires 'prompt' key in response.")
|
||||
raise KeyError("LangChain requires 'prompt' key in response.")
|
||||
text = response_json["prompt"]
|
||||
if stop is not None:
|
||||
# I believe this is required since the stop tokens
|
||||
|
||||
240  libs/langchain/langchain/llms/pai_eas_endpoint.py  Normal file
@@ -0,0 +1,240 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PaiEasEndpoint(LLM):
|
||||
"""Langchain LLM class to help to access eass llm service.
|
||||
|
||||
To use this endpoint, you must have a deployed EAS chat LLM service on PAI AliCloud.
You can set the environment variables ``eas_service_url`` and ``eas_service_token``
to your EAS service URL and service token.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
|
||||
eas_endpoint = PaiEasEndpoint(
|
||||
eas_service_url="your_service_url",
|
||||
eas_service_token="your_service_token"
|
||||
)
|
||||
"""
|
||||
|
||||
"""PAI-EAS Service URL"""
|
||||
eas_service_url: str
|
||||
|
||||
"""PAI-EAS Service TOKEN"""
|
||||
eas_service_token: str
|
||||
|
||||
"""PAI-EAS Service Infer Params"""
|
||||
max_new_tokens: Optional[int] = 512
|
||||
temperature: Optional[float] = 0.95
|
||||
top_p: Optional[float] = 0.1
|
||||
top_k: Optional[int] = 0
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
"""Enable stream chat mode."""
|
||||
streaming: bool = False
|
||||
|
||||
"""Key/value arguments to pass to the model. Reserved for future use"""
|
||||
model_kwargs: Optional[dict] = None
|
||||
|
||||
version: Optional[str] = "2.0"
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["eas_service_url"] = get_from_dict_or_env(
|
||||
values, "eas_service_url", "EAS_SERVICE_URL"
|
||||
)
|
||||
values["eas_service_token"] = get_from_dict_or_env(
|
||||
values, "eas_service_token", "EAS_SERVICE_TOKEN"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "pai_eas_endpoint"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Cohere API."""
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"stop_sequences": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
"eas_service_url": self.eas_service_url,
|
||||
"eas_service_token": self.eas_service_token,
|
||||
**_model_kwargs,
|
||||
}
|
||||
|
||||
def _invocation_params(
|
||||
self, stop_sequences: Optional[List[str]], **kwargs: Any
|
||||
) -> dict:
|
||||
params = self._default_params
|
||||
if self.stop_sequences is not None and stop_sequences is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop_sequences is not None:
|
||||
params["stop"] = self.stop_sequences
|
||||
else:
|
||||
params["stop"] = stop_sequences
|
||||
if self.model_kwargs:
|
||||
params.update(self.model_kwargs)
|
||||
return {**params, **kwargs}
|
||||
|
||||
@staticmethod
|
||||
def _process_response(
|
||||
response: Any, stop: Optional[List[str]], version: Optional[str]
|
||||
) -> str:
|
||||
if version == "1.0":
|
||||
text = response
|
||||
else:
|
||||
text = response["response"]
|
||||
|
||||
if stop:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return "".join(text)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
prompt = prompt.strip()
|
||||
response = None
|
||||
try:
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
for chunk in self._stream(prompt, stop, run_manager, **params):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
else:
|
||||
response = self._call_eas(prompt, params)
|
||||
_stop = params.get("stop")
|
||||
return self._process_response(response, _stop, self.version)
|
||||
except Exception as error:
|
||||
raise ValueError(f"Error raised by the service: {error}")
|
||||
|
||||
def _call_eas(self, prompt: str = "", params: Dict = {}) -> Any:
|
||||
"""Generate text from the eas service."""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"{self.eas_service_token}",
|
||||
}
|
||||
if self.version == "1.0":
|
||||
body = {
|
||||
"input_ids": f"{prompt}",
|
||||
}
|
||||
else:
|
||||
body = {
|
||||
"prompt": f"{prompt}",
|
||||
}
|
||||
|
||||
# add params to body
|
||||
for key, value in params.items():
|
||||
body[key] = value
|
||||
|
||||
# make request
|
||||
response = requests.post(self.eas_service_url, headers=headers, json=body)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed with status code {response.status_code}"
|
||||
f" and message {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
return json.loads(response.text)
|
||||
except Exception as e:
|
||||
if isinstance(e, json.decoder.JSONDecodeError):
|
||||
return response.text
|
||||
raise e
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
headers = {
|
||||
"User-Agent": "Test Client",
|
||||
"Authorization": f"{self.eas_service_token}",
|
||||
}
|
||||
|
||||
if self.version == "1.0":
|
||||
pload = {"input_ids": prompt, **invocation_params}
|
||||
response = requests.post(
|
||||
self.eas_service_url, headers=headers, json=pload, stream=True
|
||||
)
|
||||
|
||||
res = GenerationChunk(text=response.text)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(res.text)
|
||||
|
||||
# yield text, if any
|
||||
yield res
|
||||
else:
|
||||
pload = {"prompt": prompt, "use_stream_chat": "True", **invocation_params}
|
||||
|
||||
response = requests.post(
|
||||
self.eas_service_url, headers=headers, json=pload, stream=True
|
||||
)
|
||||
|
||||
for chunk in response.iter_lines(
|
||||
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||
):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode("utf-8"))
|
||||
output = data["response"]
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in invocation_params["stop"]:
|
||||
if stop_seq in output:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if stop_seq_found:
|
||||
text = output[: output.index(stop_seq_found)]
|
||||
else:
|
||||
text = output
|
||||
|
||||
# yield text, if any
|
||||
if text:
|
||||
res = GenerationChunk(text=text)
|
||||
yield res
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(res.text)
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
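
For reference, a minimal usage sketch of the endpoint defined above, assuming a PAI-EAS LLM service is already deployed and the two environment variables are set; the URL and token values are placeholders, and the prompts mirror the integration tests later in this changeset.

import os

from langchain.llms.pai_eas_endpoint import PaiEasEndpoint

# Both values are read by the root validator shown above if not passed explicitly.
llm = PaiEasEndpoint(
    eas_service_url=os.environ["EAS_SERVICE_URL"],
    eas_service_token=os.environ["EAS_SERVICE_TOKEN"],
    version="2.0",
)

# Blocking call.
print(llm("Say foo:"))

# Streaming call; each chunk is a piece of generated text.
for chunk in llm.stream("Say foo:", stop=["."]):
    print(chunk, end="", flush=True)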
|
||||
@@ -1,6 +1,9 @@
from typing import Any, Dict, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.load.serializable import Serializable
@@ -128,3 +131,75 @@ class YandexGPT(_BaseYandexGPT, LLM):
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Async call the Yandex GPT model and return the output.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
import grpc
|
||||
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import (
|
||||
InstructRequest,
|
||||
InstructResponse,
|
||||
)
|
||||
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
|
||||
TextGenerationAsyncServiceStub,
|
||||
)
|
||||
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
|
||||
from yandex.cloud.operation.operation_service_pb2_grpc import (
|
||||
OperationServiceStub,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
||||
) from e
|
||||
operation_api_url = "operation.api.cloud.yandex.net:443"
|
||||
channel_credentials = grpc.ssl_channel_credentials()
|
||||
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
|
||||
request = InstructRequest(
|
||||
model=self.model_name,
|
||||
request_text=prompt,
|
||||
generation_options=GenerationOptions(
|
||||
temperature=DoubleValue(value=self.temperature),
|
||||
max_tokens=Int64Value(value=self.max_tokens),
|
||||
),
|
||||
)
|
||||
stub = TextGenerationAsyncServiceStub(channel)
|
||||
if self.iam_token:
|
||||
metadata = (("authorization", f"Bearer {self.iam_token}"),)
|
||||
else:
|
||||
metadata = (("authorization", f"Api-Key {self.api_key}"),)
|
||||
operation = await stub.Instruct(request, metadata=metadata)
|
||||
async with grpc.aio.secure_channel(
|
||||
operation_api_url, channel_credentials
|
||||
) as operation_channel:
|
||||
operation_stub = OperationServiceStub(operation_channel)
|
||||
while not operation.done:
|
||||
await asyncio.sleep(1)
|
||||
operation_request = GetOperationRequest(operation_id=operation.id)
|
||||
operation = await operation_stub.Get(
|
||||
operation_request, metadata=metadata
|
||||
)
|
||||
|
||||
instruct_response = InstructResponse()
|
||||
operation.response.Unpack(instruct_response)
|
||||
text = instruct_response.alternatives[0].text
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
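
A hedged sketch of how the new async path might be exercised, assuming YandexGPT is exported from langchain.llms in this release and that an IAM token is available; the token and prompt are placeholders.

import asyncio

from langchain.llms import YandexGPT


async def main() -> None:
    # iam_token (api_key would also work) is an assumption for illustration only.
    llm = YandexGPT(iam_token="<your-iam-token>")
    # _acall above is reached through the public async API.
    result = await llm.agenerate(["Translate 'hello' into French."])
    print(result.generations[0][0].text)


asyncio.run(main())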
|
||||
|
||||
@@ -17,9 +17,12 @@ class OutputFixingParser(BaseOutputParser[T]):
return True

parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""

@classmethod
def from_llm(
@@ -35,7 +38,7 @@ class OutputFixingParser(BaseOutputParser[T]):
llm: llm to use for fixing
parser: parser to use for parsing
prompt: prompt to use for fixing
max_retries: Maximum number of retries to parser.
max_retries: Maximum number of retries to parse.

Returns:
OutputFixingParser
|
||||
|
||||
@@ -48,6 +48,8 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||
retry_chain: Any
|
||||
"""The LLMChain to use to retry the completion."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@@ -55,11 +57,23 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser[T],
|
||||
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
|
||||
max_retries: int = 1,
|
||||
) -> RetryOutputParser[T]:
|
||||
"""Create an OutputFixingParser from a language model and a parser.
|
||||
|
||||
Args:
|
||||
llm: llm to use for fixing
|
||||
parser: parser to use for parsing
|
||||
prompt: prompt to use for fixing
|
||||
max_retries: Maximum number of retries to parse.
|
||||
|
||||
Returns:
|
||||
RetryOutputParser
|
||||
"""
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(parser=parser, retry_chain=chain)
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
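
A small sketch of the new max_retries knob in use; the model and the wrapped parser are stand-ins chosen for illustration, not part of this change.

from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import DatetimeOutputParser, RetryOutputParser
from langchain.prompts import PromptTemplate

parser = RetryOutputParser.from_llm(
    llm=ChatOpenAI(temperature=0),
    parser=DatetimeOutputParser(),
    max_retries=2,  # re-ask the retry chain at most twice before raising
)

prompt = PromptTemplate.from_template("When was LangChain released? {format_instructions}")
prompt_value = prompt.format_prompt(
    format_instructions=DatetimeOutputParser().get_format_instructions()
)

# parse_with_prompt re-asks the model on OutputParserException, up to max_retries times.
result = parser.parse_with_prompt("not a date", prompt_value)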
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
"""Parse the output of an LLM call using a wrapped parser.
|
||||
@@ -71,15 +85,21 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
Returns:
|
||||
The parsed completion.
|
||||
"""
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException:
|
||||
new_completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
retries = 0
|
||||
|
||||
return parsed_completion
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise e
|
||||
else:
|
||||
retries += 1
|
||||
completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
|
||||
raise OutputParserException("Failed to parse")
|
||||
|
||||
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
"""Parse the output of an LLM call using a wrapped parser.
|
||||
@@ -91,15 +111,21 @@ class RetryOutputParser(BaseOutputParser[T]):
|
||||
Returns:
|
||||
The parsed completion.
|
||||
"""
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException:
|
||||
new_completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
retries = 0
|
||||
|
||||
return parsed_completion
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return await self.parser.aparse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise e
|
||||
else:
|
||||
retries += 1
|
||||
completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(), completion=completion
|
||||
)
|
||||
|
||||
raise OutputParserException("Failed to parse")
|
||||
|
||||
def parse(self, completion: str) -> T:
|
||||
raise NotImplementedError(
|
||||
@@ -125,8 +151,12 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
"""
|
||||
|
||||
parser: BaseOutputParser[T]
|
||||
"""The parser to use to parse the output."""
|
||||
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
|
||||
retry_chain: Any
|
||||
"""The LLMChain to use to retry the completion."""
|
||||
max_retries: int = 1
|
||||
"""The maximum number of times to retry the parse."""
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
@@ -134,6 +164,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
llm: BaseLanguageModel,
|
||||
parser: BaseOutputParser[T],
|
||||
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
|
||||
max_retries: int = 1,
|
||||
) -> RetryWithErrorOutputParser[T]:
|
||||
"""Create a RetryWithErrorOutputParser from an LLM.
|
||||
|
||||
@@ -141,6 +172,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
llm: The LLM to use to retry the completion.
|
||||
parser: The parser to use to parse the output.
|
||||
prompt: The prompt to use to retry the completion.
|
||||
max_retries: The maximum number of times to retry the completion.
|
||||
|
||||
Returns:
|
||||
A RetryWithErrorOutputParser.
|
||||
@@ -148,29 +180,45 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(parser=parser, retry_chain=chain)
|
||||
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
|
||||
|
||||
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
new_completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
retries = 0
|
||||
|
||||
return parsed_completion
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise e
|
||||
else:
|
||||
retries += 1
|
||||
completion = self.retry_chain.run(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
|
||||
raise OutputParserException("Failed to parse")
|
||||
|
||||
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
|
||||
try:
|
||||
parsed_completion = self.parser.parse(completion)
|
||||
except OutputParserException as e:
|
||||
new_completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
|
||||
)
|
||||
parsed_completion = self.parser.parse(new_completion)
|
||||
retries = 0
|
||||
|
||||
return parsed_completion
|
||||
while retries <= self.max_retries:
|
||||
try:
|
||||
return await self.parser.aparse(completion)
|
||||
except OutputParserException as e:
|
||||
if retries == self.max_retries:
|
||||
raise e
|
||||
else:
|
||||
retries += 1
|
||||
completion = await self.retry_chain.arun(
|
||||
prompt=prompt_value.to_string(),
|
||||
completion=completion,
|
||||
error=repr(e),
|
||||
)
|
||||
|
||||
raise OutputParserException("Failed to parse")
|
||||
|
||||
def parse(self, completion: str) -> T:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -32,10 +32,8 @@ from langchain.retrievers.ensemble import EnsembleRetriever
|
||||
from langchain.retrievers.google_cloud_documentai_warehouse import (
|
||||
GoogleDocumentAIWarehouseRetriever,
|
||||
)
|
||||
from langchain.retrievers.google_cloud_enterprise_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
)
|
||||
from langchain.retrievers.google_vertex_ai_search import (
|
||||
GoogleCloudEnterpriseSearchRetriever,
|
||||
GoogleVertexAIMultiTurnSearchRetriever,
|
||||
GoogleVertexAISearchRetriever,
|
||||
)
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Retriever wrapper for Google Vertex AI Search.
|
||||
DEPRECATED: Maintained for backwards compatibility.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever
|
||||
|
||||
|
||||
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
|
||||
"""`Google Vertex Search API` retriever alias for backwards compatibility.
|
||||
DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
|
||||
"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
super().__init__(**data)
|
||||
@@ -25,10 +25,21 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
|
||||
"""Vertex AI Search data store ID."""
|
||||
location_id: str = "global"
|
||||
"""Vertex AI Search data store location."""
|
||||
serving_config_id: str = "default_config"
|
||||
"""Vertex AI Search serving config ID."""
|
||||
credentials: Any = None
|
||||
"""The default custom credentials (google.auth.credentials.Credentials) to use
|
||||
when making API calls. If not provided, credentials will be ascertained from
|
||||
the environment."""
|
||||
engine_data_type: int = Field(default=0, ge=0, le=2)
|
||||
""" Defines the Vertex AI Search data type
|
||||
0 - Unstructured data
|
||||
1 - Structured data
|
||||
2 - Website data (with Advanced Website Indexing)
|
||||
"""
|
||||
|
||||
_serving_config: str
|
||||
"""Full path of serving config."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@@ -144,6 +155,47 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
|
||||
|
||||
return documents
|
||||
|
||||
def _convert_website_search_response(
|
||||
self, results: Sequence[SearchResult]
|
||||
) -> List[Document]:
|
||||
"""Converts a sequence of search results to a list of LangChain documents."""
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
documents: List[Document] = []
|
||||
|
||||
for result in results:
|
||||
document_dict = MessageToDict(
|
||||
result.document._pb, preserving_proto_field_name=True
|
||||
)
|
||||
derived_struct_data = document_dict.get("derived_struct_data")
|
||||
if not derived_struct_data:
|
||||
continue
|
||||
|
||||
doc_metadata = document_dict.get("struct_data", {})
|
||||
doc_metadata["id"] = document_dict["id"]
|
||||
doc_metadata["source"] = derived_struct_data.get("link", "")
|
||||
|
||||
chunk_type = "extractive_answers"
|
||||
|
||||
if chunk_type not in derived_struct_data:
|
||||
continue
|
||||
|
||||
for chunk in derived_struct_data[chunk_type]:
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=chunk.get("content", ""), metadata=doc_metadata
|
||||
)
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print(
|
||||
f"No {chunk_type} could be found.\n"
|
||||
"Make sure that your data store is using Advanced Website Indexing.\n"
|
||||
"https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing" # noqa: E501
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
|
||||
"""`Google Vertex AI Search` retriever.
|
||||
@@ -153,8 +205,6 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
|
||||
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
|
||||
"""
|
||||
|
||||
serving_config_id: str = "default_config"
|
||||
"""Vertex AI Search serving config ID."""
|
||||
filter: Optional[str] = None
|
||||
"""Filter expression."""
|
||||
get_extractive_answers: bool = False
|
||||
@@ -188,15 +238,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
|
||||
Search will be based on the corrected query if found.
|
||||
"""
|
||||
|
||||
# TODO: Add extra data type handling for type website
|
||||
engine_data_type: int = Field(default=0, ge=0, le=1)
|
||||
""" Defines the Vertex AI Search data type
|
||||
0 - Unstructured data
|
||||
1 - Structured data
|
||||
"""
|
||||
|
||||
_client: SearchServiceClient
|
||||
_serving_config: str
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -260,11 +302,16 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
|
||||
)
|
||||
elif self.engine_data_type == 1:
|
||||
content_search_spec = None
|
||||
elif self.engine_data_type == 2:
|
||||
content_search_spec = SearchRequest.ContentSearchSpec(
|
||||
extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
|
||||
max_extractive_answer_count=self.max_extractive_answer_count,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# TODO: Add extra data type handling for type website
|
||||
raise NotImplementedError(
|
||||
"Only engine data type 0 (Unstructured) or 1 (Structured)"
|
||||
+ " are supported currently."
|
||||
"Only data store type 0 (Unstructured), 1 (Structured),"
|
||||
"or 2 (Website with Advanced Indexing) are supported currently."
|
||||
+ f" Got {self.engine_data_type}"
|
||||
)
|
||||
|
||||
@@ -305,11 +352,12 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
|
||||
)
|
||||
elif self.engine_data_type == 1:
|
||||
documents = self._convert_structured_search_response(response.results)
|
||||
elif self.engine_data_type == 2:
|
||||
documents = self._convert_website_search_response(response.results)
|
||||
else:
|
||||
# TODO: Add extra data type handling for type website
|
||||
raise NotImplementedError(
|
||||
"Only engine data type 0 (Unstructured) or 1 (Structured)"
|
||||
+ " are supported currently."
|
||||
"Only data store type 0 (Unstructured), 1 (Structured),"
|
||||
"or 2 (Website with Advanced Indexing) are supported currently."
|
||||
+ f" Got {self.engine_data_type}"
|
||||
)
|
||||
|
||||
@@ -321,6 +369,9 @@ class GoogleVertexAIMultiTurnSearchRetriever(
|
||||
):
|
||||
"""`Google Vertex AI Search` retriever for multi-turn conversations."""
|
||||
|
||||
conversation_id: str = "-"
|
||||
"""Vertex AI Search Conversation ID."""
|
||||
|
||||
_client: ConversationalSearchServiceClient
|
||||
|
||||
class Config:
|
||||
@@ -340,6 +391,20 @@ class GoogleVertexAIMultiTurnSearchRetriever(
|
||||
credentials=self.credentials, client_options=self.client_options
|
||||
)
|
||||
|
||||
self._serving_config = self._client.serving_config_path(
|
||||
project=self.project_id,
|
||||
location=self.location_id,
|
||||
data_store=self.data_store_id,
|
||||
serving_config=self.serving_config_id,
|
||||
)
|
||||
|
||||
if self.engine_data_type == 1:
|
||||
raise NotImplementedError(
|
||||
"Data store type 1 (Structured)"
|
||||
"is not currently supported for multi-turn search."
|
||||
+ f" Got {self.engine_data_type}"
|
||||
)
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
) -> List[Document]:
|
||||
@@ -351,11 +416,35 @@ class GoogleVertexAIMultiTurnSearchRetriever(
|
||||
|
||||
request = ConverseConversationRequest(
|
||||
name=self._client.conversation_path(
|
||||
self.project_id, self.location_id, self.data_store_id, "-"
|
||||
self.project_id,
|
||||
self.location_id,
|
||||
self.data_store_id,
|
||||
self.conversation_id,
|
||||
),
|
||||
serving_config=self._serving_config,
|
||||
query=TextInput(input=query),
|
||||
)
|
||||
response = self._client.converse_conversation(request)
|
||||
|
||||
if self.engine_data_type == 2:
|
||||
return self._convert_website_search_response(response.search_results)
|
||||
|
||||
return self._convert_unstructured_search_response(
|
||||
response.search_results, "extractive_answers"
|
||||
)
|
||||
|
||||
|
||||
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
|
||||
"""`Google Vertex Search API` retriever alias for backwards compatibility.
|
||||
DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
|
||||
"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
super().__init__(**data)
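
To make the new website data type concrete, a hedged sketch of a retriever configured for engine_data_type=2; the project and data store IDs are placeholders, credentials are taken from the environment, and Advanced Website Indexing must be enabled on the data store.

from langchain.retrievers import GoogleVertexAISearchRetriever

retriever = GoogleVertexAISearchRetriever(
    project_id="my-gcp-project",
    data_store_id="my-website-data-store",
    engine_data_type=2,  # 0 = unstructured, 1 = structured, 2 = website (Advanced Indexing)
    max_extractive_answer_count=3,
)

docs = retriever.get_relevant_documents("What is Vertex AI Search?")
for doc in docs:
    # "source" is populated from the page link by _convert_website_search_response above.
    print(doc.metadata.get("source"), doc.page_content[:80])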
|
||||
|
||||
@@ -39,7 +39,7 @@ class ElasticsearchTranslator(Visitor):
Comparator.LT: "lt",
Comparator.LTE: "lte",
Comparator.CONTAIN: "match",
Comparator.LIKE: "fuzzy",
Comparator.LIKE: "match",
}
return map_dict[func]

@@ -67,15 +67,19 @@ class ElasticsearchTranslator(Visitor):
}
}

if comparison.comparator == Comparator.LIKE:
if comparison.comparator == Comparator.CONTAIN:
return {
self._format_func(comparison.comparator): {
field: {"value": comparison.value, "fuzziness": "AUTO"}
field: {"query": comparison.value}
}
}

if comparison.comparator == Comparator.CONTAIN:
return {self._format_func(comparison.comparator): {field: comparison.value}}
if comparison.comparator == Comparator.LIKE:
return {
self._format_func(comparison.comparator): {
field: {"query": comparison.value, "fuzziness": "AUTO"}
}
}

# we assume that if the value is a string,
# we want to use the keyword field
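In effect, a self-query CONTAIN comparison is now translated to a plain match query and LIKE keeps the fuzziness option. A sketch mirroring the updated unit tests further below:

from langchain.chains.query_constructor.ir import Comparator, Comparison
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator

translator = ElasticsearchTranslator()

contain = Comparison(comparator=Comparator.CONTAIN, attribute="foo", value="1")
like = Comparison(comparator=Comparator.LIKE, attribute="foo", value="bar")

print(translator.visit_comparison(contain))
# {'match': {'metadata.foo': {'query': '1'}}}
print(translator.visit_comparison(like))
# {'match': {'metadata.foo': {'query': 'bar', 'fuzziness': 'AUTO'}}}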
File diff suppressed because it is too large
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, ValidationError
|
||||
from langchain.pydantic_v1 import ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,288 +38,278 @@ class HTTPVerb(str, Enum):
|
||||
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
|
||||
|
||||
|
||||
if _PYDANTIC_MAJOR_VERSION == 1:
|
||||
if TYPE_CHECKING:
|
||||
from openapi_schema_pydantic import (
|
||||
Components,
|
||||
Operation,
|
||||
Parameter,
|
||||
PathItem,
|
||||
Paths,
|
||||
Reference,
|
||||
RequestBody,
|
||||
Schema,
|
||||
if TYPE_CHECKING:
|
||||
from openapi_pydantic import (
|
||||
Components,
|
||||
Operation,
|
||||
Parameter,
|
||||
PathItem,
|
||||
Paths,
|
||||
Reference,
|
||||
RequestBody,
|
||||
Schema,
|
||||
)
|
||||
|
||||
try:
|
||||
from openapi_pydantic import OpenAPI
|
||||
except ImportError:
|
||||
OpenAPI = object # type: ignore
|
||||
|
||||
|
||||
class OpenAPISpec(OpenAPI):
|
||||
"""OpenAPI Model that removes mis-formatted parts of the spec."""
|
||||
|
||||
openapi: str = "3.1.0" # overriding overly restrictive type from parent class
|
||||
|
||||
@property
|
||||
def _paths_strict(self) -> Paths:
|
||||
if not self.paths:
|
||||
raise ValueError("No paths found in spec")
|
||||
return self.paths
|
||||
|
||||
def _get_path_strict(self, path: str) -> PathItem:
|
||||
path_item = self._paths_strict.get(path)
|
||||
if not path_item:
|
||||
raise ValueError(f"No path found for {path}")
|
||||
return path_item
|
||||
|
||||
@property
|
||||
def _components_strict(self) -> Components:
|
||||
"""Get components or err."""
|
||||
if self.components is None:
|
||||
raise ValueError("No components found in spec. ")
|
||||
return self.components
|
||||
|
||||
@property
|
||||
def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
|
||||
"""Get parameters or err."""
|
||||
parameters = self._components_strict.parameters
|
||||
if parameters is None:
|
||||
raise ValueError("No parameters found in spec. ")
|
||||
return parameters
|
||||
|
||||
@property
|
||||
def _schemas_strict(self) -> Dict[str, Schema]:
|
||||
"""Get the dictionary of schemas or err."""
|
||||
schemas = self._components_strict.schemas
|
||||
if schemas is None:
|
||||
raise ValueError("No schemas found in spec. ")
|
||||
return schemas
|
||||
|
||||
@property
|
||||
def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
|
||||
"""Get the request body or err."""
|
||||
request_bodies = self._components_strict.requestBodies
|
||||
if request_bodies is None:
|
||||
raise ValueError("No request body found in spec. ")
|
||||
return request_bodies
|
||||
|
||||
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
|
||||
"""Get a parameter (or nested reference) or err."""
|
||||
ref_name = ref.ref.split("/")[-1]
|
||||
parameters = self._parameters_strict
|
||||
if ref_name not in parameters:
|
||||
raise ValueError(f"No parameter found for {ref_name}")
|
||||
return parameters[ref_name]
|
||||
|
||||
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
|
||||
"""Get the root reference or err."""
|
||||
from openapi_pydantic import Reference
|
||||
|
||||
parameter = self._get_referenced_parameter(ref)
|
||||
while isinstance(parameter, Reference):
|
||||
parameter = self._get_referenced_parameter(parameter)
|
||||
return parameter
|
||||
|
||||
def get_referenced_schema(self, ref: Reference) -> Schema:
|
||||
"""Get a schema (or nested reference) or err."""
|
||||
ref_name = ref.ref.split("/")[-1]
|
||||
schemas = self._schemas_strict
|
||||
if ref_name not in schemas:
|
||||
raise ValueError(f"No schema found for {ref_name}")
|
||||
return schemas[ref_name]
|
||||
|
||||
def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
|
||||
from openapi_pydantic import Reference
|
||||
|
||||
if isinstance(schema, Reference):
|
||||
return self.get_referenced_schema(schema)
|
||||
return schema
|
||||
|
||||
def _get_root_referenced_schema(self, ref: Reference) -> Schema:
|
||||
"""Get the root reference or err."""
|
||||
from openapi_pydantic import Reference
|
||||
|
||||
schema = self.get_referenced_schema(ref)
|
||||
while isinstance(schema, Reference):
|
||||
schema = self.get_referenced_schema(schema)
|
||||
return schema
|
||||
|
||||
def _get_referenced_request_body(
|
||||
self, ref: Reference
|
||||
) -> Optional[Union[Reference, RequestBody]]:
|
||||
"""Get a request body (or nested reference) or err."""
|
||||
ref_name = ref.ref.split("/")[-1]
|
||||
request_bodies = self._request_bodies_strict
|
||||
if ref_name not in request_bodies:
|
||||
raise ValueError(f"No request body found for {ref_name}")
|
||||
return request_bodies[ref_name]
|
||||
|
||||
def _get_root_referenced_request_body(
|
||||
self, ref: Reference
|
||||
) -> Optional[RequestBody]:
|
||||
"""Get the root request Body or err."""
|
||||
from openapi_pydantic import Reference
|
||||
|
||||
request_body = self._get_referenced_request_body(ref)
|
||||
while isinstance(request_body, Reference):
|
||||
request_body = self._get_referenced_request_body(request_body)
|
||||
return request_body
|
||||
|
||||
@staticmethod
|
||||
def _alert_unsupported_spec(obj: dict) -> None:
|
||||
"""Alert if the spec is not supported."""
|
||||
warning_message = (
|
||||
" This may result in degraded performance."
|
||||
+ " Convert your OpenAPI spec to 3.1.* spec"
|
||||
+ " for better support."
|
||||
)
|
||||
|
||||
try:
|
||||
from openapi_schema_pydantic import OpenAPI
|
||||
except ImportError:
|
||||
OpenAPI = object # type: ignore
|
||||
|
||||
class OpenAPISpec(OpenAPI):
|
||||
"""OpenAPI Model that removes mis-formatted parts of the spec."""
|
||||
|
||||
@property
|
||||
def _paths_strict(self) -> Paths:
|
||||
if not self.paths:
|
||||
raise ValueError("No paths found in spec")
|
||||
return self.paths
|
||||
|
||||
def _get_path_strict(self, path: str) -> PathItem:
|
||||
path_item = self._paths_strict.get(path)
|
||||
if not path_item:
|
||||
raise ValueError(f"No path found for {path}")
|
||||
return path_item
|
||||
|
||||
@property
|
||||
def _components_strict(self) -> Components:
|
||||
"""Get components or err."""
|
||||
if self.components is None:
|
||||
raise ValueError("No components found in spec. ")
|
||||
return self.components
|
||||
|
||||
@property
|
||||
def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
|
||||
"""Get parameters or err."""
|
||||
parameters = self._components_strict.parameters
|
||||
if parameters is None:
|
||||
raise ValueError("No parameters found in spec. ")
|
||||
return parameters
|
||||
|
||||
@property
|
||||
def _schemas_strict(self) -> Dict[str, Schema]:
|
||||
"""Get the dictionary of schemas or err."""
|
||||
schemas = self._components_strict.schemas
|
||||
if schemas is None:
|
||||
raise ValueError("No schemas found in spec. ")
|
||||
return schemas
|
||||
|
||||
@property
|
||||
def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
|
||||
"""Get the request body or err."""
|
||||
request_bodies = self._components_strict.requestBodies
|
||||
if request_bodies is None:
|
||||
raise ValueError("No request body found in spec. ")
|
||||
return request_bodies
|
||||
|
||||
def _get_referenced_parameter(
|
||||
self, ref: Reference
|
||||
) -> Union[Parameter, Reference]:
|
||||
"""Get a parameter (or nested reference) or err."""
|
||||
ref_name = ref.ref.split("/")[-1]
|
||||
parameters = self._parameters_strict
|
||||
if ref_name not in parameters:
|
||||
raise ValueError(f"No parameter found for {ref_name}")
|
||||
return parameters[ref_name]
|
||||
|
||||
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
|
||||
"""Get the root reference or err."""
|
||||
from openapi_schema_pydantic import Reference
|
||||
|
||||
parameter = self._get_referenced_parameter(ref)
|
||||
while isinstance(parameter, Reference):
|
||||
parameter = self._get_referenced_parameter(parameter)
|
||||
return parameter
|
||||
|
||||
def get_referenced_schema(self, ref: Reference) -> Schema:
|
||||
"""Get a schema (or nested reference) or err."""
|
||||
ref_name = ref.ref.split("/")[-1]
|
||||
schemas = self._schemas_strict
|
||||
if ref_name not in schemas:
|
||||
raise ValueError(f"No schema found for {ref_name}")
|
||||
return schemas[ref_name]
|
||||
|
||||
def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
|
||||
from openapi_schema_pydantic import Reference
|
||||
|
||||
if isinstance(schema, Reference):
|
||||
return self.get_referenced_schema(schema)
|
||||
return schema
|
||||
|
||||
def _get_root_referenced_schema(self, ref: Reference) -> Schema:
|
||||
"""Get the root reference or err."""
|
||||
from openapi_schema_pydantic import Reference
|
||||
|
||||
schema = self.get_referenced_schema(ref)
|
||||
while isinstance(schema, Reference):
|
||||
schema = self.get_referenced_schema(schema)
|
||||
return schema
|
||||
|
||||
def _get_referenced_request_body(
|
||||
self, ref: Reference
|
||||
) -> Optional[Union[Reference, RequestBody]]:
|
||||
"""Get a request body (or nested reference) or err."""
|
||||
ref_name = ref.ref.split("/")[-1]
|
||||
request_bodies = self._request_bodies_strict
|
||||
if ref_name not in request_bodies:
|
||||
raise ValueError(f"No request body found for {ref_name}")
|
||||
return request_bodies[ref_name]
|
||||
|
||||
def _get_root_referenced_request_body(
|
||||
self, ref: Reference
|
||||
) -> Optional[RequestBody]:
|
||||
"""Get the root request Body or err."""
|
||||
from openapi_schema_pydantic import Reference
|
||||
|
||||
request_body = self._get_referenced_request_body(ref)
|
||||
while isinstance(request_body, Reference):
|
||||
request_body = self._get_referenced_request_body(request_body)
|
||||
return request_body
|
||||
|
||||
@staticmethod
|
||||
def _alert_unsupported_spec(obj: dict) -> None:
|
||||
"""Alert if the spec is not supported."""
|
||||
warning_message = (
|
||||
" This may result in degraded performance."
|
||||
+ " Convert your OpenAPI spec to 3.1.* spec"
|
||||
+ " for better support."
|
||||
)
|
||||
swagger_version = obj.get("swagger")
|
||||
openapi_version = obj.get("openapi")
|
||||
if isinstance(openapi_version, str):
|
||||
if openapi_version != "3.1.0":
|
||||
logger.warning(
|
||||
f"Attempting to load an OpenAPI {openapi_version}"
|
||||
f" spec. {warning_message}"
|
||||
)
|
||||
else:
|
||||
pass
|
||||
elif isinstance(swagger_version, str):
|
||||
swagger_version = obj.get("swagger")
|
||||
openapi_version = obj.get("openapi")
|
||||
if isinstance(openapi_version, str):
|
||||
if openapi_version != "3.1.0":
|
||||
logger.warning(
|
||||
f"Attempting to load a Swagger {swagger_version}"
|
||||
f"Attempting to load an OpenAPI {openapi_version}"
|
||||
f" spec. {warning_message}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Attempting to load an unsupported spec:"
|
||||
f"\n\n{obj}\n{warning_message}"
|
||||
)
|
||||
pass
|
||||
elif isinstance(swagger_version, str):
|
||||
logger.warning(
|
||||
f"Attempting to load a Swagger {swagger_version}"
|
||||
f" spec. {warning_message}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Attempting to load an unsupported spec:"
|
||||
f"\n\n{obj}\n{warning_message}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse_obj(cls, obj: dict) -> OpenAPISpec:
|
||||
try:
|
||||
cls._alert_unsupported_spec(obj)
|
||||
return super().parse_obj(obj)
|
||||
except ValidationError as e:
|
||||
# We are handling possibly misconfigured specs and
|
||||
# want to do a best-effort job to get a reasonable interface out of it.
|
||||
new_obj = copy.deepcopy(obj)
|
||||
for error in e.errors():
|
||||
keys = error["loc"]
|
||||
item = new_obj
|
||||
for key in keys[:-1]:
|
||||
item = item[key]
|
||||
item.pop(keys[-1], None)
|
||||
return cls.parse_obj(new_obj)
|
||||
@classmethod
|
||||
def parse_obj(cls, obj: dict) -> OpenAPISpec:
|
||||
try:
|
||||
cls._alert_unsupported_spec(obj)
|
||||
return super().parse_obj(obj)
|
||||
except ValidationError as e:
|
||||
# We are handling possibly misconfigured specs and
|
||||
# want to do a best-effort job to get a reasonable interface out of it.
|
||||
new_obj = copy.deepcopy(obj)
|
||||
for error in e.errors():
|
||||
keys = error["loc"]
|
||||
item = new_obj
|
||||
for key in keys[:-1]:
|
||||
item = item[key]
|
||||
item.pop(keys[-1], None)
|
||||
return cls.parse_obj(new_obj)
|
||||
|
||||
@classmethod
|
||||
def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a dict."""
|
||||
return cls.parse_obj(spec_dict)
|
||||
@classmethod
|
||||
def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a dict."""
|
||||
return cls.parse_obj(spec_dict)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, text: str) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a text."""
|
||||
try:
|
||||
spec_dict = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
spec_dict = yaml.safe_load(text)
|
||||
return cls.from_spec_dict(spec_dict)
|
||||
@classmethod
|
||||
def from_text(cls, text: str) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a text."""
|
||||
try:
|
||||
spec_dict = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
spec_dict = yaml.safe_load(text)
|
||||
return cls.from_spec_dict(spec_dict)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a file path."""
|
||||
path_ = path if isinstance(path, Path) else Path(path)
|
||||
if not path_.exists():
|
||||
raise FileNotFoundError(f"{path} does not exist")
|
||||
with path_.open("r") as f:
|
||||
return cls.from_text(f.read())
|
||||
@classmethod
|
||||
def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a file path."""
|
||||
path_ = path if isinstance(path, Path) else Path(path)
|
||||
if not path_.exists():
|
||||
raise FileNotFoundError(f"{path} does not exist")
|
||||
with path_.open("r") as f:
|
||||
return cls.from_text(f.read())
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a URL."""
|
||||
response = requests.get(url)
|
||||
return cls.from_text(response.text)
|
||||
@classmethod
|
||||
def from_url(cls, url: str) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a URL."""
|
||||
response = requests.get(url)
|
||||
return cls.from_text(response.text)
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
"""Get the base url."""
|
||||
return self.servers[0].url
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
"""Get the base url."""
|
||||
return self.servers[0].url
|
||||
|
||||
def get_methods_for_path(self, path: str) -> List[str]:
|
||||
"""Return a list of valid methods for the specified path."""
|
||||
from openapi_schema_pydantic import Operation
|
||||
def get_methods_for_path(self, path: str) -> List[str]:
|
||||
"""Return a list of valid methods for the specified path."""
|
||||
from openapi_pydantic import Operation
|
||||
|
||||
path_item = self._get_path_strict(path)
|
||||
results = []
|
||||
for method in HTTPVerb:
|
||||
operation = getattr(path_item, method.value, None)
|
||||
if isinstance(operation, Operation):
|
||||
results.append(method.value)
|
||||
return results
|
||||
path_item = self._get_path_strict(path)
|
||||
results = []
|
||||
for method in HTTPVerb:
|
||||
operation = getattr(path_item, method.value, None)
|
||||
if isinstance(operation, Operation):
|
||||
results.append(method.value)
|
||||
return results
|
||||
|
||||
def get_parameters_for_path(self, path: str) -> List[Parameter]:
|
||||
from openapi_schema_pydantic import Reference
|
||||
def get_parameters_for_path(self, path: str) -> List[Parameter]:
|
||||
from openapi_pydantic import Reference
|
||||
|
||||
path_item = self._get_path_strict(path)
|
||||
parameters = []
|
||||
if not path_item.parameters:
|
||||
return []
|
||||
for parameter in path_item.parameters:
|
||||
path_item = self._get_path_strict(path)
|
||||
parameters = []
|
||||
if not path_item.parameters:
|
||||
return []
|
||||
for parameter in path_item.parameters:
|
||||
if isinstance(parameter, Reference):
|
||||
parameter = self._get_root_referenced_parameter(parameter)
|
||||
parameters.append(parameter)
|
||||
return parameters
|
||||
|
||||
def get_operation(self, path: str, method: str) -> Operation:
|
||||
"""Get the operation object for a given path and HTTP method."""
|
||||
from openapi_pydantic import Operation
|
||||
|
||||
path_item = self._get_path_strict(path)
|
||||
operation_obj = getattr(path_item, method, None)
|
||||
if not isinstance(operation_obj, Operation):
|
||||
raise ValueError(f"No {method} method found for {path}")
|
||||
return operation_obj
|
||||
|
||||
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
|
||||
"""Get the components for a given operation."""
|
||||
from openapi_pydantic import Reference
|
||||
|
||||
parameters = []
|
||||
if operation.parameters:
|
||||
for parameter in operation.parameters:
|
||||
if isinstance(parameter, Reference):
|
||||
parameter = self._get_root_referenced_parameter(parameter)
|
||||
parameters.append(parameter)
|
||||
return parameters
|
||||
return parameters
|
||||
|
||||
def get_operation(self, path: str, method: str) -> Operation:
|
||||
"""Get the operation object for a given path and HTTP method."""
|
||||
from openapi_schema_pydantic import Operation
|
||||
def get_request_body_for_operation(
|
||||
self, operation: Operation
|
||||
) -> Optional[RequestBody]:
|
||||
"""Get the request body for a given operation."""
|
||||
from openapi_pydantic import Reference
|
||||
|
||||
path_item = self._get_path_strict(path)
|
||||
operation_obj = getattr(path_item, method, None)
|
||||
if not isinstance(operation_obj, Operation):
|
||||
raise ValueError(f"No {method} method found for {path}")
|
||||
return operation_obj
|
||||
request_body = operation.requestBody
|
||||
if isinstance(request_body, Reference):
|
||||
request_body = self._get_root_referenced_request_body(request_body)
|
||||
return request_body
|
||||
|
||||
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
|
||||
"""Get the components for a given operation."""
|
||||
from openapi_schema_pydantic import Reference
|
||||
|
||||
parameters = []
|
||||
if operation.parameters:
|
||||
for parameter in operation.parameters:
|
||||
if isinstance(parameter, Reference):
|
||||
parameter = self._get_root_referenced_parameter(parameter)
|
||||
parameters.append(parameter)
|
||||
return parameters
|
||||
|
||||
def get_request_body_for_operation(
|
||||
self, operation: Operation
|
||||
) -> Optional[RequestBody]:
|
||||
"""Get the request body for a given operation."""
|
||||
from openapi_schema_pydantic import Reference
|
||||
|
||||
request_body = operation.requestBody
|
||||
if isinstance(request_body, Reference):
|
||||
request_body = self._get_root_referenced_request_body(request_body)
|
||||
return request_body
|
||||
|
||||
@staticmethod
|
||||
def get_cleaned_operation_id(
|
||||
operation: Operation, path: str, method: str
|
||||
) -> str:
|
||||
"""Get a cleaned operation id from an operation id."""
|
||||
operation_id = operation.operationId
|
||||
if operation_id is None:
|
||||
# Replace all punctuation of any kind with underscore
|
||||
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
|
||||
operation_id = f"{path}_{method}"
|
||||
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
|
||||
|
||||
else:
|
||||
|
||||
class OpenAPISpec: # type: ignore[no-redef]
|
||||
"""Shim for pydantic version >=2"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError("Only supported for pydantic version 1")
|
||||
@staticmethod
|
||||
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
|
||||
"""Get a cleaned operation id from an operation id."""
|
||||
operation_id = operation.operationId
|
||||
if operation_id is None:
|
||||
# Replace all punctuation of any kind with underscore
|
||||
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
|
||||
operation_id = f"{path}_{method}"
|
||||
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
|
||||
|
||||
@@ -156,12 +156,22 @@ class ElasticVectorSearch(VectorStore):
self.index_name = index_name
_ssl_verify = ssl_verify or {}
try:
self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify)
self.client = elasticsearch.Elasticsearch(
elasticsearch_url,
**_ssl_verify,
headers={"user-agent": self.get_user_agent()},
)
except ValueError as e:
raise ValueError(
f"Your elasticsearch client string is mis-formatted. Got error: {e} "
)

@staticmethod
def get_user_agent() -> str:
from langchain import __version__

return f"langchain-py-dvs/{__version__}"

@property
def embeddings(self) -> Embeddings:
return self.embedding
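A brief sketch of what the new header amounts to, assuming an Elasticsearch instance is reachable at the given URL; the URL, index name, and embedding model are placeholders.

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import ElasticVectorSearch

# The client created in __init__ now sends a user-agent such as "langchain-py-dvs/0.0.318".
store = ElasticVectorSearch(
    elasticsearch_url="http://localhost:9200",
    index_name="test_index",
    embedding=FakeEmbeddings(size=4),
)
print(ElasticVectorSearch.get_user_agent())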
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
@@ -15,6 +16,7 @@ from typing import (
|
||||
Optional,
|
||||
Sized,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
@@ -26,6 +28,8 @@ from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.vectorstore import VectorStore
|
||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
||||
"""
|
||||
@@ -82,7 +86,7 @@ class FAISS(VectorStore):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Callable,
|
||||
embedding_function: Union[Callable, Embeddings],
|
||||
index: Any,
|
||||
docstore: Docstore,
|
||||
index_to_docstore_id: Dict[int, str],
|
||||
@@ -91,6 +95,11 @@ class FAISS(VectorStore):
|
||||
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
if not isinstance(embedding_function, Embeddings):
|
||||
logger.warning(
|
||||
"`embedding_function` is expected to be an Embeddings object, support "
|
||||
"for passing in a function will soon be removed."
|
||||
)
|
||||
self.embedding_function = embedding_function
|
||||
self.index = index
|
||||
self.docstore = docstore
|
||||
@@ -108,6 +117,26 @@ class FAISS(VectorStore):
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return (
|
||||
self.embedding_function
|
||||
if isinstance(self.embedding_function, Embeddings)
|
||||
else None
|
||||
)
|
||||
|
||||
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
if isinstance(self.embedding_function, Embeddings):
|
||||
return self.embedding_function.embed_documents(texts)
|
||||
else:
|
||||
return [self.embedding_function(text) for text in texts]
|
||||
|
||||
def _embed_query(self, text: str) -> List[float]:
|
||||
if isinstance(self.embedding_function, Embeddings):
|
||||
return self.embedding_function.embed_query(text)
|
||||
else:
|
||||
return self.embedding_function(text)
|
||||
|
||||
def __add(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@@ -163,7 +192,8 @@ class FAISS(VectorStore):
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
embeddings = [self.embedding_function(text) for text in texts]
|
||||
texts = list(texts)
|
||||
embeddings = self._embed_documents(texts)
|
||||
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
|
||||
|
||||
def add_embeddings(
|
||||
@@ -272,7 +302,7 @@ class FAISS(VectorStore):
|
||||
List of documents most similar to the query text with
|
||||
L2 distance in float. Lower score represents more similarity.
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
embedding = self._embed_query(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k,
|
||||
@@ -465,7 +495,7 @@ class FAISS(VectorStore):
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
embedding = self._embed_query(query)
|
||||
docs = self.max_marginal_relevance_search_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
@@ -561,7 +591,7 @@ class FAISS(VectorStore):
|
||||
# Default to L2, currently other metric types not initialized.
|
||||
index = faiss.IndexFlatL2(len(embeddings[0]))
|
||||
vecstore = cls(
|
||||
embedding.embed_query,
|
||||
embedding,
|
||||
index,
|
||||
InMemoryDocstore(),
|
||||
{},
|
||||
@@ -696,9 +726,7 @@ class FAISS(VectorStore):
|
||||
# load docstore and index_to_docstore_id
|
||||
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
|
||||
docstore, index_to_docstore_id = pickle.load(f)
|
||||
return cls(
|
||||
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
|
||||
)
|
||||
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
|
||||
|
||||
def serialize_to_bytes(self) -> bytes:
|
||||
"""Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
|
||||
@@ -713,9 +741,7 @@ class FAISS(VectorStore):
|
||||
) -> FAISS:
|
||||
"""Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
|
||||
index, docstore, index_to_docstore_id = pickle.loads(serialized)
|
||||
return cls(
|
||||
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
|
||||
)
|
||||
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
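
With this change the FAISS constructor, and helpers such as from_texts and load_local, can be handed the Embeddings object itself rather than its embed_query method. A short sketch, assuming faiss-cpu is installed; the text and embedding model are placeholders.

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import FAISS

embeddings = FakeEmbeddings(size=32)

# Passing the Embeddings object keeps db.embeddings usable and avoids the deprecation warning.
db = FAISS.from_texts(["harrison worked at kensho"], embeddings)
assert db.embeddings is embeddings

docs = db.similarity_search("where did harrison work?", k=1)
print(docs[0].page_content)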
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
|
@@ -209,6 +209,8 @@ class Weaviate(VectorStore):
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("tenant"):
query_obj = query_obj.with_tenant(kwargs.get("tenant"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_near_text(content).with_limit(k).do()
@@ -228,6 +230,8 @@ class Weaviate(VectorStore):
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("tenant"):
query_obj = query_obj.with_tenant(kwargs.get("tenant"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_near_vector(vector).with_limit(k).do()
@@ -304,6 +308,8 @@ class Weaviate(VectorStore):
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("tenant"):
query_obj = query_obj.with_tenant(kwargs.get("tenant"))
results = (
query_obj.with_additional("vector")
.with_near_vector(vector)
@@ -343,6 +349,8 @@ class Weaviate(VectorStore):
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("tenant"):
query_obj = query_obj.with_tenant(kwargs.get("tenant"))

embedded_query = self._embedding.embed_query(query)
if not self._by_text:
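The practical effect is that multi-tenant Weaviate collections can now be queried by passing tenant as a keyword argument. A hedged sketch; the client URL, index name, tenant name, and embedding model are placeholders.

import weaviate

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import Weaviate

client = weaviate.Client("http://localhost:8080")
db = Weaviate(client, index_name="Document", text_key="text", embedding=FakeEmbeddings(size=8))

# The new kwargs.get("tenant") branches add .with_tenant(...) to the underlying query.
docs = db.similarity_search("what did the president say?", k=4, tenant="tenantA")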
||||
2176 libs/langchain/poetry.lock (generated)
File diff suppressed because it is too large
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.317"
version = "0.0.318"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -20,7 +20,7 @@ PyYAML = ">=5.3"
numpy = "^1"
azure-core = {version = "^1.26.4", optional=true}
tqdm = {version = ">=4.48.0", optional = true}
openapi-schema-pydantic = {version = "^1.2", optional = true}
openapi-pydantic = {version = "^0.3.2", optional = true}
faiss-cpu = {version = "^1", optional = true}
wikipedia = {version = "^1", optional = true}
elasticsearch = {version = "^8", optional = true}
@@ -359,7 +359,7 @@ extended_testing = [
"xata",
"xmltodict",
"faiss-cpu",
"openapi-schema-pydantic",
"openapi-pydantic",
"markdownify",
"arxiv",
"dashvector",
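In other words, the optional OpenAPI dependency moves from openapi-schema-pydantic to openapi-pydantic. A hedged sketch of picking it up and loading a spec with the class shown earlier; the spec URL and path are placeholders.

# pip install openapi-pydantic   (replaces: pip install openapi-schema-pydantic)
from langchain.utilities.openapi import OpenAPISpec

# Any reachable OpenAPI 3.x document works here; this URL is only an example.
spec = OpenAPISpec.from_url("https://example.com/openapi.json")
print(spec.base_url)
print(spec.get_methods_for_path("/pets"))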
||||
@@ -0,0 +1,82 @@
|
||||
"""Test AliCloud Pai Eas Chat Model."""
|
||||
import os
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
HumanMessage,
|
||||
LLMResult,
|
||||
)
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def test_pai_eas_call() -> None:
|
||||
chat = PaiEasChatEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
)
|
||||
response = chat(messages=[HumanMessage(content="Say foo:")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_history() -> None:
|
||||
"""Tests multiple history works."""
|
||||
chat = PaiEasChatEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
)
|
||||
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="How are you doing?"),
|
||||
]
|
||||
)
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test that stream works."""
|
||||
chat = PaiEasChatEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
streaming=True,
|
||||
)
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="Who are you?"),
|
||||
],
|
||||
stream=True,
|
||||
callbacks=callback_manager,
|
||||
)
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_messages() -> None:
|
||||
"""Tests multiple messages works."""
|
||||
chat = PaiEasChatEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
)
|
||||
message = HumanMessage(content="Hi, how are you.")
|
||||
response = chat.generate([[message], [message]])
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
@@ -0,0 +1,59 @@
"""Test PaiEasEndpoint API wrapper."""
import os
from typing import Generator

from langchain.llms.pai_eas_endpoint import PaiEasEndpoint


def test_pai_eas_v1_call() -> None:
    """Test valid call to PAI-EAS Service."""
    llm = PaiEasEndpoint(
        eas_service_url=os.getenv("EAS_SERVICE_URL"),
        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
        version="1.0",
    )
    output = llm("Say foo:")
    assert isinstance(output, str)


def test_pai_eas_v2_call() -> None:
    llm = PaiEasEndpoint(
        eas_service_url=os.getenv("EAS_SERVICE_URL"),
        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
        version="2.0",
    )
    output = llm("Say foo:")
    assert isinstance(output, str)


def test_pai_eas_v1_streaming() -> None:
    """Test streaming call to PAI-EAS Service."""
    llm = PaiEasEndpoint(
        eas_service_url=os.getenv("EAS_SERVICE_URL"),
        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
        version="1.0",
    )
    generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
    stream_results_string = ""
    assert isinstance(generator, Generator)

    for chunk in generator:
        assert isinstance(chunk, str)
        stream_results_string = chunk
    assert len(stream_results_string.strip()) > 1


def test_pai_eas_v2_streaming() -> None:
    llm = PaiEasEndpoint(
        eas_service_url=os.getenv("EAS_SERVICE_URL"),
        eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
        version="2.0",
    )
    generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
    stream_results_string = ""
    assert isinstance(generator, Generator)

    for chunk in generator:
        assert isinstance(chunk, str)
        stream_results_string = chunk
    assert len(stream_results_string.strip()) > 1
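For the completion-style endpoint, a minimal streaming sketch, assuming the same credential and `version` parameters shown in the tests above and that the environment variables are set (prompt and stop token are illustrative):

import os

from langchain.llms.pai_eas_endpoint import PaiEasEndpoint

llm = PaiEasEndpoint(
    eas_service_url=os.environ["EAS_SERVICE_URL"],
    eas_service_token=os.environ["EAS_SERVICE_TOKEN"],
    version="2.0",
)

# Print string chunks as they arrive instead of waiting for the full completion.
for chunk in llm.stream("Q: How do you say 'hello' in German? A:", stop=["."]):
    print(chunk, end="", flush=True)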
@@ -15,10 +15,8 @@ import os
import pytest

from langchain.retrievers.google_cloud_enterprise_search import (
    GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import (
    GoogleCloudEnterpriseSearchRetriever,
    GoogleVertexAIMultiTurnSearchRetriever,
    GoogleVertexAISearchRetriever,
)
@@ -531,7 +531,7 @@ class TestElasticsearch:
                },
            }
        },
        settings={"index": {"default_pipeline": "pipeline"}},
        settings={"index": {"default_pipeline": "test_pipeline"}},
    )

    # adding documents to the index
@@ -49,14 +49,14 @@ def test_visit_comparison_range_lte() -> None:
def test_visit_comparison_range_match() -> None:
    comp = Comparison(comparator=Comparator.CONTAIN, attribute="foo", value="1")
    expected = {"match": {"metadata.foo": "1"}}
    expected = {"match": {"metadata.foo": {"query": "1"}}}
    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
    assert expected == actual


def test_visit_comparison_range_like() -> None:
    comp = Comparison(comparator=Comparator.LIKE, attribute="foo", value="bar")
    expected = {"fuzzy": {"metadata.foo": {"value": "bar", "fuzziness": "AUTO"}}}
    expected = {"match": {"metadata.foo": {"query": "bar", "fuzziness": "AUTO"}}}
    actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
    assert expected == actual
@@ -200,9 +200,9 @@ def test_visit_structured_query_complex() -> None:
            "should": [
                {"range": {"metadata.bar": {"lt": 1}}},
                {
                    "fuzzy": {
                    "match": {
                        "metadata.bar": {
                            "value": "10",
                            "query": "10",
                            "fuzziness": "AUTO",
                        }
                    }
@@ -2163,19 +2163,19 @@
    dict({
      'anyOf': list([
        dict({
          '$ref': '#/definitions/HumanMessage',
          '$ref': '#/definitions/AIMessage',
        }),
        dict({
          '$ref': '#/definitions/AIMessage',
          '$ref': '#/definitions/HumanMessage',
        }),
        dict({
          '$ref': '#/definitions/ChatMessage',
        }),
        dict({
          '$ref': '#/definitions/FunctionMessage',
          '$ref': '#/definitions/SystemMessage',
        }),
        dict({
          '$ref': '#/definitions/SystemMessage',
          '$ref': '#/definitions/FunctionMessage',
        }),
      ]),
      'definitions': dict({
@@ -7,7 +7,7 @@ from typing import Iterable, List, Tuple
import pytest

# Keep at top of file to ensure that pydantic test can be skipped before
# pydantic v1 related imports are attempted by openapi_schema_pydantic.
# pydantic v1 related imports are attempted by openapi_pydantic.
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION

if _PYDANTIC_MAJOR_VERSION != 1:
@@ -78,7 +78,7 @@ def http_paths_and_methods() -> List[Tuple[str, OpenAPISpec, str, str]]:
    return http_paths_and_methods


@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_parse_api_operations() -> None:
    """Test the APIOperation class."""
    for spec_name, spec, path, method in http_paths_and_methods():
@@ -88,21 +88,21 @@ def test_parse_api_operations() -> None:
        raise AssertionError(f"Error processing {spec_name}: {e} ") from e


@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
@pytest.fixture
def raw_spec() -> OpenAPISpec:
    """Return a raw OpenAPI spec."""
    from openapi_schema_pydantic import Info
    from openapi_pydantic import Info

    return OpenAPISpec(
        info=Info(title="Test API", version="1.0.0"),
    )


@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> None:
    """Test instantiating APIRequestBody from RequestBody with a reference."""
    from openapi_schema_pydantic import (
    from openapi_pydantic import (
        Components,
        MediaType,
        Reference,
@@ -140,10 +140,10 @@ def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> N
    assert api_request_body.media_type == "application/json"


@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -> None:
    """Test instantiating APIRequestBody from RequestBody with a schema."""
    from openapi_schema_pydantic import (
    from openapi_pydantic import (
        MediaType,
        RequestBody,
        Schema,
@@ -171,9 +171,9 @@ def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -
    assert api_request_body.media_type == "application/json"


@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_property_from_schema(raw_spec: OpenAPISpec) -> None:
    from openapi_schema_pydantic import (
    from openapi_pydantic import (
        Components,
        Reference,
        Schema,