Compare commits


19 Commits

Author SHA1 Message Date
Bagatur
76d3afaef0 bump 318 (#12030) 2023-10-19 09:33:39 -07:00
Dmitry Tyumentsev
5dd2161c4b add _acall method to YandexGPT (#12029)
- **Description:** Add async support for YandexGPT LLM model

Co-authored-by: Dmitry Tyumentsev <dmitry.tyumentsev@raftds.com>
2023-10-19 09:15:26 -07:00
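A minimal sketch of what the added `_acall` enables, assuming credentials for the `YandexGPT` integration are already configured (the credential setup shown in the comment is an assumption and may differ):

```python
import asyncio

from langchain.llms import YandexGPT

# Assumes Yandex Cloud credentials (IAM token or API key) are configured
# for the integration, e.g. via environment variables.
llm = YandexGPT()

async def main() -> None:
    # With _acall implemented, the LLM can participate in async APIs such
    # as agenerate instead of blocking the event loop.
    result = await llm.agenerate(["Say hello in Russian"])
    print(result.generations[0][0].text)

asyncio.run(main())
```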
Palau
720ecacb1c Add notebook for kay.ai press release data (#11575)
- **Description:** Adding a notebook for Press Release data from Kay.ai,
as discussed offline
  - **Tag maintainer:** @baskaryan @hwchase17 
- **Twitter handle:** https://twitter.com/kaydotai
https://twitter.com/vishalrohra_

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-10-19 08:06:56 -07:00
Peter Krenesky
8425f33363 Pydantic v2 support for OpenAPI Specs (#11936)
- **Description:** Adding Pydantic v2 support for OpenAPI Specs 

- **Issue:**
  - OpenAPI spec support was disabled because `openapi-schema-pydantic`
    doesn't support Pydantic v2: #9205
  - This caused errors in `get_openapi_chain`
  - This may be the cause of #9520.

- **Tag maintainer:** @eyurtsev
- **Twitter handle:** kreneskyp


The root cause was that `openapi-schema-pydantic` hasn't been updated in
some time;
[openapi-pydantic](https://github.com/mike-oakley/openapi-pydantic)
is a fork that has updated the project.
2023-10-19 11:06:11 -04:00
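A hedged sketch of the chain this re-enables; the spec URL follows the LangChain docs example and is illustrative, and an `OPENAI_API_KEY` is assumed to be set:

```python
from langchain.chains.openai_functions.openapi import get_openapi_chain

# Builds a chain that turns a natural-language request into a call against
# the operations described by an OpenAPI spec. Requires OPENAI_API_KEY.
chain = get_openapi_chain(
    "https://www.klarna.com/us/shopping/public/openai/v0/api-docs/"  # illustrative spec
)
response = chain.run("What are some options for a men's large blue button down shirt?")
print(response)
```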
volodymyr-memsql
4adabd33ac Add example of retriever usage with SingleStoreDB vector store (#12021)
Added a notebook with examples of creating a retriever from the
SingleStoreDB vector store and using it.

Co-authored-by: Volodymyr Tkachuk <vtkachuk-ua@singlestore.com>
2023-10-19 09:48:35 -04:00
Joe McElroy
c9f1768cb9 Elasticsearch Query Retriever: Use match + fuzziness for LIKE (#12023)
Updated the Elasticsearch self-query retriever to use the `match` clause
for the LIKE operator instead of the non-analyzed fuzzy search clause.

Other small updates include:
- fixing the stack inference integration test, where the index's default
pipeline didn't use the inference pipeline created
- adding a user-agent to the old implementation to track usage
- improving the documentation for ElasticsearchStore filters
2023-10-19 09:47:21 -04:00
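A hedged sketch of the query DSL change described above; the field, value, and fuzziness setting are illustrative, and the exact clause the translator emits may differ:

```python
# Before: a `fuzzy` query runs against the raw, non-analyzed term.
old_like_clause = {"fuzzy": {"director": {"value": "Gerwig"}}}

# After: a `match` query analyzes the text first and can apply fuzziness,
# which behaves better for natural-language LIKE comparisons.
new_like_clause = {"match": {"director": {"query": "Gerwig", "fuzziness": "AUTO"}}}
```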
maks-operlejn-ds
84d250f781 Docs: QA Privacy Nit (#12025)
Resize image in docs for QA Privacy
2023-10-19 09:43:47 -04:00
Nuno Campos
7db6aabf65 Update chat model output type (#11833)
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-10-19 00:55:15 -07:00
Simon Dai
ed62984cb2 update Weaviate to support multi tenancy (#11842)
- **Description:** update Weaviate to support multi tenancy
  - **Issue:** 9956
  - **Dependencies:** 
  - **Tag maintainer:** hwchase17
  - **Twitter handle:** dsx1986_
2023-10-19 00:49:30 -07:00
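The commit message doesn't show the LangChain-level API surface, so as a hedged illustration of the underlying Weaviate feature only: enabling multi-tenancy on a class with the v3 Python client (the endpoint and class name are placeholders):

```python
import weaviate

client = weaviate.Client("http://localhost:8080")  # placeholder endpoint

# Multi-tenancy is opt-in per class in Weaviate; the collection backing
# the vector store must be created with it enabled before tenants are used.
client.schema.create_class(
    {
        "class": "LangChainDocs",  # placeholder class name
        "multiTenancyConfig": {"enabled": True},
    }
)
```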
hiigao
f818ec49b8 Encapsulate alicloud pai-eas access method for chatmodels and llms (#11852)
### Description: 
This pull request provides EAS LLM service access methods by implementing
the `PaiEasEndpoint` and `PaiEasChatEndpoint` classes in the
`langchain.llms` and `langchain.chat_models` modules. Based on this PR,
LangChain users can build a chain that calls a remote EAS LLM service and
gets the inference results.

### About EAS Service
EAS is a product of Alibaba Cloud's Machine Learning Platform for AI,
AliCloud PAI for short. EAS provides model inference deployment services
for its users. We built an LLM inference service on EAS with a
general-purpose LLM Docker image, so end users can quickly set up remote
LLM instances that load the majority of Hugging Face LLM models and serve
as a backend for most LLM apps.

### Dependencies
This PR doesn't involve any new dependencies.

---------

Co-authored-by: 子洪 <gaoyihong.gyh@alibaba-inc.com>
2023-10-19 00:20:18 -07:00
Shinya Maeda
1da6d92369 fix: superfluous List Parser doc (#12014) 2023-10-19 00:14:38 -07:00
John Mai
a6b483dcbc Supported RetryOutputParser & RetryWithErrorOutputParser max_retries (#11903)
Description: Support a `max_retries` parameter on RetryOutputParser &
RetryWithErrorOutputParser.
- max_retries: Maximum number of retries for parsing.

Issue: None
Dependencies: None
Tag maintainer: @baskaryan 
Twitter handle:
2023-10-18 23:57:16 -07:00
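A small sketch of the new knob, assuming `from_llm` accepts it directly; the base parser and LLM choice here are illustrative:

```python
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import DatetimeOutputParser, RetryOutputParser

# Retry parsing at most twice before giving up.
retry_parser = RetryOutputParser.from_llm(
    llm=ChatOpenAI(temperature=0),
    parser=DatetimeOutputParser(),  # illustrative base parser
    max_retries=2,  # the parameter added by this PR
)
```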
Hugues Chocart
008c7df80d [LLMonitorCallbackHandler] Refactor + add llmonitor-py dependency (#11948)
We now require users to have the pip package `llmonitor` installed. This
allows us to have cleaner code and avoid duplication between our library
and our code in LangChain.
2023-10-18 23:54:10 -07:00
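Usage presumably looks roughly like the following, assuming the handler is importable from `langchain.callbacks`; the app ID is a placeholder:

```python
# pip install llmonitor
from langchain.callbacks import LLMonitorCallbackHandler
from langchain.llms import OpenAI

handler = LLMonitorCallbackHandler(app_id="YOUR_APP_ID")  # placeholder ID

# Attach the handler so calls through this LLM are tracked by llmonitor.
llm = OpenAI(callbacks=[handler])
llm.predict("Hello!")
```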
Sian Cao
77fc2f7644 fix: impl missing embeddings method (#10823)
FAISS did not implement the embeddings method and used `embed_query` to
embed texts, which is wrong for some embedding models.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2023-10-18 23:51:28 -07:00
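The distinction matters because some embedding models are asymmetric: queries and documents are embedded differently. A sketch of the intended split, with an illustrative embeddings class:

```python
from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings()  # illustrative model choice

# Documents should go through embed_documents...
doc_vectors = embeddings.embed_documents(["first text", "second text"])

# ...while embed_query is reserved for the search query. For asymmetric
# models (e.g. instruction-prefixed retrievers) the two are not interchangeable.
query_vector = embeddings.embed_query("what is the first text about?")
```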
Holt Skinner
2661dc94f3 feat: Google Vertex AI Search Retriever - Add support for Website Data Stores (#11736)
- Only works for Data stores with Advanced Website Indexing
-
https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features
- Minor restructuring - Follow up to #10513
- Remove outdated docs (re-added in
https://github.com/langchain-ai/langchain/pull/11620)
  - Move legacy class into new py file to clean up the directory
- Shouldn't cause backwards compatibility issues as the import works the
same way for users
2023-10-18 23:41:48 -07:00
Shorthills AI
4b6fdd7bf0 Update modal.py (#11588)
feat: Raise KeyError when 'prompt' key is missing in JSON response

This commit updates the error handling in the code to raise a KeyError
when the 'prompt' key is not found in the JSON response. This change
makes the code more explicit about the nature of the error, helping to
improve clarity and debugging.

@baskaryan, @eyurtsev.
2023-10-18 23:40:37 -07:00
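A sketch of the described pattern; the function and variable names are hypothetical, not the actual `modal.py` code:

```python
def extract_prompt(response_json: dict) -> str:
    # Fail loudly and precisely when the expected key is absent,
    # rather than surfacing a vague error downstream.
    if "prompt" not in response_json:
        raise KeyError("The 'prompt' key is missing in the JSON response.")
    return response_json["prompt"]
```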
Surav Shrestha
2038c7fd5d fix typo in multi_language.ipynb (#12009)
exprience -> experience
2023-10-18 23:33:25 -07:00
William FH
dfb4baa3f9 Fix Fireworks Callbacks (#12003)
I may be missing something, but it seems we inappropriately overrode the
`stream()` method, losing callbacks in the process. I don't think
overriding it here gained us any customization.

See new trace:

https://smith.langchain.com/public/fbb82825-3a16-446b-8207-35622358db3b/r

and confirmed it streams.

Also fixes the stopwords issues from #12000
2023-10-18 23:33:09 -07:00
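With the override removed, streaming goes through the base `stream()` implementation, which wires up callbacks. A sketch, assuming `FIREWORKS_API_KEY` is set and the default model is acceptable:

```python
from langchain.llms.fireworks import Fireworks

llm = Fireworks()  # assumes FIREWORKS_API_KEY is set in the environment

# The base-class stream() now handles callbacks while streaming.
for chunk in llm.stream("Tell me a short story"):
    print(chunk, end="", flush=True)
```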
Lance Martin
12f8e87a0e LLaMA2 SQL cookbook clean (#12007) 2023-10-18 21:16:58 -07:00
45 changed files with 3771 additions and 2513 deletions

View File

@@ -40,6 +40,7 @@ Notebook | Description
[openai_functions_retrieval_qa....](https://github.com/langchain-ai/langchain/tree/master/cookbook/openai_functions_retrieval_qa.ipynb) | Structure response output in a question answering system by incorporating openai functions into a retrieval pipeline.
[petting_zoo.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/petting_zoo.ipynb) | Create multi-agent simulations with simulated environments using the petting zoo library.
[plan_and_execute_agent.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/plan_and_execute_agent.ipynb) | Create plan-and-execute agents that accomplish objectives by planning tasks with a language model (llm) and executing them with a separate agent.
[press_releases.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/press_releases.ipynb) | Retrieve and query company press release data powered by [Kay.ai](https://kay.ai).
[program_aided_language_model.i...](https://github.com/langchain-ai/langchain/tree/master/cookbook/program_aided_language_model.ipynb) | Implement program-aided language models as described in the provided research paper.
[sales_agent_with_context.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/sales_agent_with_context.ipynb) | Implement a context-aware ai sales agent, salesgpt, that can have natural sales conversations, interact with other systems, and use a product knowledge base to discuss a company's offerings.
[self_query_hotel_search.ipynb](https://github.com/langchain-ai/langchain/tree/master/cookbook/self_query_hotel_search.ipynb) | Build a hotel room search feature with self-querying retrieval, using a specific hotel recommendation dataset.

View File

@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "62ee82e4-2ad8-498b-8438-fac388afe1a2",
"metadata": {},
"source": [
"Press Releases Data\n",
"=\n",
"\n",
"Press Releases data powered by [Kay.ai](https://kay.ai).\n",
"\n",
">Press releases are used by companies to announce something noteworthy, including product launches, financial performance reports, partnerships, and other significant news. They are widely used by analysts to track corporate strategy, operational updates and financial performance.\n",
"Kay.ai obtains press releases of all US public companies from a variety of sources, which include the company's official press room and partnerships with various data API providers. \n",
"This data is updated till Sept 30th for free access, if you want to access the real-time feed, reach out to us at hello@kay.ai or [tweet at us](https://twitter.com/vishalrohra_)"
]
},
{
"cell_type": "markdown",
"id": "8183d85d-365f-4672-a963-52b533547de0",
"metadata": {},
"source": [
"Setup\n",
"=\n",
"\n",
"First you will need to install the `kay` package. You will also need an API key: you can get one for free at [https://kay.ai](https://kay.ai/). Once you have an API key, you must set it as an environment variable `KAY_API_KEY`.\n",
"\n",
"In this example we're going to use the `KayAiRetriever`. Take a look at the [kay notebook](/docs/integrations/retrievers/kay) for more detailed information for the parmeters that it accepts."
]
},
{
"cell_type": "markdown",
"id": "02ec21c7-49fe-4844-b58a-bf064ad40b2a",
"metadata": {},
"source": [
"Examples\n",
"="
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bf0395f7-6ebe-4136-8b0d-00b9dea3becd",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
" ········\n",
" ········\n"
]
}
],
"source": [
"# Setup API keys for Kay and OpenAI\n",
"from getpass import getpass\n",
"KAY_API_KEY = getpass()\n",
"OPENAI_API_KEY = getpass()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f7fcaf70-29a4-444b-8f07-9784f808c300",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"KAY_API_KEY\"] = KAY_API_KEY\n",
"os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ac00bf93-3635-4ffe-b9a6-a8b4f35c0c85",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import ConversationalRetrievalChain\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.retrievers import KayAiRetriever\n",
"\n",
"model = ChatOpenAI(model_name=\"gpt-3.5-turbo\")\n",
"retriever = KayAiRetriever.create(dataset_id=\"company\", data_types=[\"PressRelease\"], num_contexts=6)\n",
"qa = ConversationalRetrievalChain.from_llm(model, retriever=retriever)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8d9d927c-35b2-4a7b-8ea7-4d0350797941",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-> **Question**: How is the healthcare industry adopting generative AI tools? \n",
"\n",
"**Answer**: The healthcare industry is adopting generative AI tools to improve various aspects of patient care and administrative tasks. Companies like HCA Healthcare Inc, Amazon Com Inc, and Mayo Clinic have collaborated with technology providers like Google Cloud, AWS, and Microsoft to implement generative AI solutions.\n",
"\n",
"HCA Healthcare is testing a nurse handoff tool that generates draft reports quickly and accurately, which nurses have shown interest in using. They are also exploring the use of Google's medically-tuned Med-PaLM 2 LLM to support caregivers in asking complex medical questions.\n",
"\n",
"Amazon Web Services (AWS) has introduced AWS HealthScribe, a generative AI-powered service that automatically creates clinical documentation. However, integrating multiple AI systems into a cohesive solution requires significant engineering resources, including access to AI experts, healthcare data, and compute capacity.\n",
"\n",
"Mayo Clinic is among the first healthcare organizations to deploy Microsoft 365 Copilot, a generative AI service that combines large language models with organizational data from Microsoft 365. This tool has the potential to automate tasks like form-filling, relieving administrative burdens on healthcare providers and allowing them to focus more on patient care.\n",
"\n",
"Overall, the healthcare industry is recognizing the potential benefits of generative AI tools in improving efficiency, automating tasks, and enhancing patient care. \n",
"\n"
]
}
],
"source": [
"# More sample questions in the Playground on https://kay.ai\n",
"questions = [\n",
" \"How is the healthcare industry adopting generative AI tools?\",\n",
" #\"What are some recent challenges faced by the renewable energy sector?\",\n",
"]\n",
"chat_history = []\n",
"\n",
"for question in questions:\n",
" result = qa({\"question\": question, \"chat_history\": chat_history})\n",
" chat_history.append((question, result[\"answer\"]))\n",
" print(f\"-> **Question**: {question} \\n\")\n",
" print(f\"**Answer**: {result['answer']} \\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -229,7 +229,7 @@
"- fasttext (recommended)\n",
"- langdetect\n",
"\n",
"From our exprience *fasttext* performs a bit better, but you should verify it on your use case."
"From our experience *fasttext* performs a bit better, but you should verify it on your use case."
]
},
{

View File

@@ -21,7 +21,7 @@
"\n",
"In this notebook, we will look at building a basic system for question answering, based on private data. Before feeding the LLM with this data, we need to protect it so that it doesn't go to an external API (e.g. OpenAI, Anthropic). Then, after receiving the model output, we would like the data to be restored to its original form. Below you can observe an example flow of this QA system:\n",
"\n",
"<img src=\"/img/qa_privacy_protection.png\" width=\"800\"/>\n",
"<img src=\"/img/qa_privacy_protection.png\" width=\"900\"/>\n",
"\n",
"\n",
"In the following notebook, we will not go into the details of how the anonymizer works. If you are interested, please visit [this part of the documentation](https://python.langchain.com/docs/guides/privacy/presidio_data_anonymization/).\n",
@@ -839,6 +839,8 @@
"metadata": {},
"outputs": [],
"source": [
"documents = [Document(page_content=document_content)]\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)\n",
"chunks = text_splitter.split_documents(documents)\n",
"\n",

View File

@@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AliCloud PAI EAS\n",
"Machine Learning Platform for AI of Alibaba Cloud is a machine learning or deep learning engineering platform intended for enterprises and developers. It provides easy-to-use, cost-effective, high-performance, and easy-to-scale plug-ins that can be applied to various industry scenarios. With over 140 built-in optimization algorithms, Machine Learning Platform for AI provides whole-process AI engineering capabilities including data labeling (PAI-iTAG), model building (PAI-Designer and PAI-DSW), model training (PAI-DLC), compilation optimization, and inference deployment (PAI-EAS). PAI-EAS supports different types of hardware resources, including CPUs and GPUs, and features high throughput and low latency. It allows you to deploy large-scale complex models with a few clicks and perform elastic scale-ins and scale-outs in real time. It also provides a comprehensive O&M and monitoring system."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Eas Service\n",
"\n",
"One who want to use eas llms must set up eas service first. When the eas service is launched, eas_service_rul and eas_service token can be got. Users can refer to https://www.alibabacloud.com/help/en/pai/user-guide/service-deployment/ for more information. Try to set environment variables to init eas service url and token:\n",
"\n",
"```base\n",
"export EAS_SERVICE_URL=XXX\n",
"export EAS_SERVICE_TOKEN=XXX\n",
"```\n",
"or run as follow codes:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from langchain.chat_models.base import HumanMessage\n",
"from langchain.chat_models import PaiEasChatEndpoint\n",
"os.environ[\"EAS_SERVICE_URL\"] = \"Your_EAS_Service_URL\"\n",
"os.environ[\"EAS_SERVICE_TOKEN\"] = \"Your_EAS_Service_Token\"\n",
"chat = PaiEasChatEndpoint(\n",
" eas_service_url=os.environ[\"EAS_SERVICE_URL\"], \n",
" eas_service_token=os.environ[\"EAS_SERVICE_TOKEN\"]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run Chat Model\n",
"You can use the default settings to call eas service as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = chat([HumanMessage(content=\"write a funny joke\")])\n",
"print(\"output:\", output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or, call eas service with new inference params:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kwargs = {\"temperature\": 0.8, \"top_p\": 0.8, \"top_k\": 5}\n",
"output = chat([HumanMessage(content=\"write a funny joke\")], **kwargs)\n",
"print(\"output:\", output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or, run a stream call to get a stream response:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"outputs = chat.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
"for output in outputs:\n",
" print(\"stream output:\", output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,93 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AliCloud PAI EAS\n",
"Machine Learning Platform for AI of Alibaba Cloud is a machine learning or deep learning engineering platform intended for enterprises and developers. It provides easy-to-use, cost-effective, high-performance, and easy-to-scale plug-ins that can be applied to various industry scenarios. With over 140 built-in optimization algorithms, Machine Learning Platform for AI provides whole-process AI engineering capabilities including data labeling (PAI-iTAG), model building (PAI-Designer and PAI-DSW), model training (PAI-DLC), compilation optimization, and inference deployment (PAI-EAS). PAI-EAS supports different types of hardware resources, including CPUs and GPUs, and features high throughput and low latency. It allows you to deploy large-scale complex models with a few clicks and perform elastic scale-ins and scale-outs in real time. It also provides a comprehensive O&M and monitoring system."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms.pai_eas_endpoint import PaiEasEndpoint\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One who want to use eas llms must set up eas service first. When the eas service is launched, eas_service_rul and eas_service token can be got. Users can refer to https://www.alibabacloud.com/help/en/pai/user-guide/service-deployment/ for more information,"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"EAS_SERVICE_URL\"] = \"Your_EAS_Service_URL\"\n",
"os.environ[\"EAS_SERVICE_TOKEN\"] = \"Your_EAS_Service_Token\"\n",
"llm = PaiEasEndpoint(eas_service_url=os.environ[\"EAS_SERVICE_URL\"], eas_service_token=os.environ[\"EAS_SERVICE_TOKEN\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' Thank you for asking! However, I must respectfully point out that the question contains an error. Justin Bieber was born in 1994, and the Super Bowl was first played in 1967. Therefore, it is not possible for any NFL team to have won the Super Bowl in the year Justin Bieber was born.\\n\\nI hope this clarifies things! If you have any other questions, please feel free to ask.'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,272 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Google Cloud Enterprise Search\n",
"\n",
"\n",
"[Enterprise Search](https://cloud.google.com/enterprise-search) is a part of the Generative AI App Builder suite of tools offered by Google Cloud.\n",
"\n",
"Gen AI App Builder lets developers, even those with limited machine learning skills, quickly and easily tap into the power of Googles foundation models, search expertise, and conversational AI technologies to create enterprise-grade generative AI applications. \n",
"\n",
"Enterprise Search lets organizations quickly build generative AI powered search engines for customers and employees.Enterprise Search is underpinned by a variety of Google Search technologies, including semantic search, which helps deliver more relevant results than traditional keyword-based search techniques by using natural language processing and machine learning techniques to infer relationships within the content and intent from the users query input. Enterprise Search also benefits from Googles expertise in understanding how users search and factors in content relevance to order displayed results. \n",
"\n",
"Google Cloud offers Enterprise Search via Gen App Builder in Google Cloud Console and via an API for enterprise workflow integration. \n",
"\n",
"This notebook demonstrates how to configure Enterprise Search and use the Enterprise Search retriever. The Enterprise Search retriever encapsulates the [Generative AI App Builder Python client library](https://cloud.google.com/generative-ai-app-builder/docs/libraries#client-libraries-install-python) and uses it to access the Enterprise Search [Search Service API](https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine_v1beta.services.search_service)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install pre-requisites\n",
"\n",
"You need to install the `google-cloud-discoverengine` package to use the Enterprise Search retriever."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install google-cloud-discoveryengine"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure access to Google Cloud and Google Cloud Enterprise Search\n",
"\n",
"Enterprise Search is generally available for the allowlist (which means customers need to be approved for access) as of June 6, 2023. Contact your Google Cloud sales team for access and pricing details. We are previewing additional features that are coming soon to the generally available offering as part of our [Trusted Tester](https://cloud.google.com/ai/earlyaccess/join?hl=en) program. Sign up for [Trusted Tester](https://cloud.google.com/ai/earlyaccess/join?hl=en) and contact your Google Cloud sales team for an expedited trial.\n",
"\n",
"Before you can run this notebook you need to:\n",
"- Set or create a Google Cloud project and turn on Gen App Builder\n",
"- Create and populate an unstructured data store\n",
"- Set credentials to access `Enterprise Search API`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set or create a Google Cloud poject and turn on Gen App Builder\n",
"\n",
"Follow the instructions in the [Enterprise Search Getting Started guide](https://cloud.google.com/generative-ai-app-builder/docs/before-you-begin) to set/create a GCP project and enable Gen App Builder.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create and populate an unstructured data store\n",
"\n",
"[Use Google Cloud Console to create an unstructured data store](https://cloud.google.com/generative-ai-app-builder/docs/create-engine-es#unstructured-data) and populate it with the example PDF documents from the `gs://cloud-samples-data/gen-app-builder/search/alphabet-investor-pdfs` Cloud Storage folder. Make sure to use the `Cloud Storage (without metadata)` option."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set credentials to access Enterprise Search API\n",
"\n",
"The [Gen App Builder client libraries](https://cloud.google.com/generative-ai-app-builder/docs/libraries) used by the Enterprise Search retriever provide high-level language support for authenticating to Gen App Builder programmatically. Client libraries support [Application Default Credentials (ADC)](https://cloud.google.com/docs/authentication/application-default-credentials); the libraries look for credentials in a set of defined locations and use those credentials to authenticate requests to the API. With ADC, you can make credentials available to your application in a variety of environments, such as local development or production, without needing to modify your application code.\n",
"\n",
"If running in [Google Colab](https://colab.google) authenticate with `google.colab.google.auth` otherwise follow one of the [supported methods](https://cloud.google.com/docs/authentication/application-default-credentials) to make sure that you Application Default Credentials are properly set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" from google.colab import auth as google_auth\n",
"\n",
" google_auth.authenticate_user()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure and use the Enterprise Search retriever\n",
"\n",
"The Enterprise Search retriever is implemented in the `langchain.retriever.GoogleCloudEntepriseSearchRetriever` class. The `get_relevant_documents` method returns a list of `langchain.schema.Document` documents where the `page_content` field of each document is populated the document content.\n",
"Depending on the data type used in Enterprise search (structured or unstructured) the `page_content` field is populated as follows:\n",
"- Structured data source: either an `extractive segment` or an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
"- Unstructured data source: a string json containing all the fields returned from the structured data source. The `metadata` field is populated with metadata (if any) of the document \n",
"\n",
"### Only for Unstructured data sources:\n",
"An extractive answer is verbatim text that is returned with each search result. It is extracted directly from the original document. Extractive answers are typically displayed near the top of web pages to provide an end user with a brief answer that is contextually relevant to their query. Extractive answers are available for website and unstructured search.\n",
"\n",
"An extractive segment is verbatim text that is returned with each search result. An extractive segment is usually more verbose than an extractive answer. Extractive segments can be displayed as an answer to a query, and can be used to perform post-processing tasks and as input for large language models to generate answers or new text. Extractive segments are available for unstructured search.\n",
"\n",
"For more information about extractive segments and extractive answers refer to [product documentation](https://cloud.google.com/generative-ai-app-builder/docs/snippets).\n",
"\n",
"When creating an instance of the retriever you can specify a number of parameters that control which Enterprise data store to access and how a natural language query is processed, including configurations for extractive answers and segments.\n",
"\n",
"\n",
"### The mandatory parameters are:\n",
"\n",
"- `project_id` - Your Google Cloud PROJECT_ID\n",
"- `search_engine_id` - The ID of the data store you want to use. \n",
"\n",
"The `project_id` and `search_engine_id` parameters can be provided explicitly in the retriever's constructor or through the environment variables - `PROJECT_ID` and `SEARCH_ENGINE_ID`.\n",
"\n",
"You can also configure a number of optional parameters, including:\n",
"\n",
"- `max_documents` - The maximum number of documents used to provide extractive segments or extractive answers\n",
"- `get_extractive_answers` - By default, the retriever is configured to return extractive segments. Set this field to `True` to return extractive answers. This is used only when `engine_data_type` set to 0 (unstructured) \n",
"- `max_extractive_answer_count` - The maximum number of extractive answers returned in each search result.\n",
" At most 5 answers will be returned. This is used only when `engine_data_type` set to 0 (unstructured) \n",
"- `max_extractive_segment_count` - The maximum number of extractive segments returned in each search result.\n",
" Currently one segment will be returned. This is used only when `engine_data_type` set to 0 (unstructured) \n",
"- `filter` - The filter expression that allows you filter the search results based on the metadata associated with the documents in the searched data store. \n",
"- `query_expansion_condition` - Specification to determine under which conditions query expansion should occur.\n",
" 0 - Unspecified query expansion condition. In this case, server behavior defaults to disabled.\n",
" 1 - Disabled query expansion. Only the exact search query is used, even if SearchResponse.total_size is zero.\n",
" 2 - Automatic query expansion built by the Search API.\n",
"- `engine_data_type` - Defines the enterprise search data type\n",
" 0 - Unstructured data \n",
" 1 - Structured data\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retriever for **unstructured** data with extractve segments "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import GoogleCloudEnterpriseSearchRetriever\n",
"\n",
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
"SEARCH_ENGINE_ID = \"<YOUR SEARCH ENGINE ID>\" # Set to your data store ID"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = GoogleCloudEnterpriseSearchRetriever(\n",
" project_id=PROJECT_ID,\n",
" search_engine_id=SEARCH_ENGINE_ID,\n",
" max_documents=3,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"What are Alphabet's Other Bets?\"\n",
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retriever for **unstructured** data with extractve answers "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = GoogleCloudEnterpriseSearchRetriever(\n",
" project_id=PROJECT_ID,\n",
" search_engine_id=SEARCH_ENGINE_ID,\n",
" max_documents=3,\n",
" max_extractive_answer_count=3,\n",
" get_extractive_answers=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"What are Alphabet's Other Bets?\"\n",
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retriever for **structured** data with extractve answers "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = GoogleCloudEnterpriseSearchRetriever(\n",
" project_id=PROJECT_ID,\n",
" search_engine_id=SEARCH_ENGINE_ID,\n",
" max_documents=3,\n",
" engine_data_type=1\n",
")\n",
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -30,7 +30,7 @@
"metadata": {},
"outputs": [],
"source": [
"! pip install google-cloud-discoveryengine"
"! pip install google-cloud-discoveryengine\n"
]
},
{
@@ -80,7 +80,7 @@
"if \"google.colab\" in sys.modules:\n",
" from google.colab import auth as google_auth\n",
"\n",
" google_auth.authenticate_user()"
" google_auth.authenticate_user()\n"
]
},
{
@@ -90,12 +90,13 @@
"## Configure and use the Vertex AI Search retriever\n",
"\n",
"The Vertex AI Search retriever is implemented in the `langchain.retriever.GoogleVertexAISearchRetriever` class. The `get_relevant_documents` method returns a list of `langchain.schema.Document` documents where the `page_content` field of each document is populated the document content.\n",
"Depending on the data type used in Vertex AI Search (structured or unstructured) the `page_content` field is populated as follows:\n",
"Depending on the data type used in Vertex AI Search (website, structured or unstructured) the `page_content` field is populated as follows:\n",
"\n",
"- Structured data source: either an `extractive segment` or an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
"- Unstructured data source: a string json containing all the fields returned from the structured data source. The `metadata` field is populated with metadata (if any) of the document\n",
"- Website with advanced indexing: an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
"- Unstructured data source: either an `extractive segment` or an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
"- Structured data source: a string json containing all the fields returned from the structured data source. The `metadata` field is populated with metadata (if any) of the document\n",
"\n",
"### Only for Unstructured data sources:\n",
"### Extractive answers & extractive segments\n",
"\n",
"An extractive answer is verbatim text that is returned with each search result. It is extracted directly from the original document. Extractive answers are typically displayed near the top of web pages to provide an end user with a brief answer that is contextually relevant to their query. Extractive answers are available for website and unstructured search.\n",
"\n",
@@ -136,6 +137,7 @@
"- `engine_data_type` - Defines the Vertex AI Search data type\n",
" - `0` - Unstructured data\n",
" - `1` - Structured data\n",
" - `2` - Website data with [Advanced Website Indexing](https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing)\n",
"\n",
"### Migration guide for `GoogleCloudEnterpriseSearchRetriever`\n",
"\n",
@@ -165,7 +167,7 @@
"\n",
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
"LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n",
"DATA_STORE_ID = \"<YOUR DATA STORE ID>\" # Set to your data store ID"
"DATA_STORE_ID = \"<YOUR DATA STORE ID>\" # Set to your data store ID\n"
]
},
{
@@ -179,7 +181,7 @@
" location_id=LOCATION_ID,\n",
" data_store_id=DATA_STORE_ID,\n",
" max_documents=3,\n",
")"
")\n"
]
},
{
@@ -192,7 +194,7 @@
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
" print(doc)\n"
]
},
{
@@ -219,7 +221,7 @@
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
" print(doc)\n"
]
},
{
@@ -245,21 +247,44 @@
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
" print(doc)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retrieve for multi-turn search"
"### Configure and use the retriever for **website** data with Advanced Website Indexing\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = GoogleVertexAISearchRetriever(\n",
" project_id=PROJECT_ID,\n",
" location_id=LOCATION_ID,\n",
" data_store_id=DATA_STORE_ID,\n",
" max_documents=3,\n",
" max_extractive_answer_count=3,\n",
" get_extractive_answers=True,\n",
" engine_data_type=2,\n",
")\n",
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Search with follow-ups is [based](https://cloud.google.com/generative-ai-app-builder/docs/multi-turn-search) on generative AI models and it is different from the regular unstructured data search."
"### Configure and use the retriever for multi-turn search\n",
"\n",
"[Search with follow-ups](https://cloud.google.com/generative-ai-app-builder/docs/multi-turn-search) is based on generative AI models and it is different from the regular unstructured data search.\n"
]
},
{
@@ -276,7 +301,7 @@
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
" print(doc)\n"
]
}
],

View File

@@ -64,13 +64,13 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 3,
"id": "b4d4d386-2a6b-4942-863e-9202f5a9f1d6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import KayAiRetriever\n",
"import os\n",
"from langchain.retrievers import KayAiRetriever\n",
"from kay.rag.retrievers import KayRetriever\n",
"os.environ[\"KAY_API_KEY\"] = KAY_API_KEY\n",
"retriever = KayAiRetriever.create(dataset_id=\"company\", data_types=[\"10-K\", \"10-Q\", \"PressRelease\"], num_contexts=3)\n",
@@ -79,19 +79,19 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 4,
"id": "04ee2d6b-c2ab-4e15-8a8b-afaf6ef8c0f6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='Company Name: ROKU INC\\nCompany Industry: CABLE & OTHER PAY TELEVISION SERVICES\\nArticle Title: Roku and FreeWheel Announce Strategic Partnership to Bring Rokus Leading Ad Tech to FreeWheel Customers\\nText: Additionally, eMarketer Link: https://cts.businesswire.com/ct/CT?id=smartlink&url=https%3A%2F%2Fwww.insiderintelligence.com%2Finsights%2Favod-more-than-50-percent-of-us-digital-video-viewers%2F&esheet=53451144&newsitemid=20230712907788&lan=en-US&anchor=eMarketer&index=4&md5=b64dea72bcf6b6379474462602781d83 projects 57% of U.S. digital video users will stream an advertising-based video on demand (AVOD) service this year.\\nHaving solutions aimed at driving greater interoperability and automation will help accelerate this growth.\\nKey highlights of this collaboration include:\\nStreamlined Integration: Roku has now integrated its demand application programming interface (dAPI) with FreeWheel s TV platform. Roku s demand API gives publishers direct, automatic and real-time access to more advertiser demand. This enhanced integration allows for streamlined ad operation workflows and better inventory quality control, both of which will improve publisher yield and revenue.\\nSeamless Data Targeting: Publishers can now use Roku platform signals to enable advertisers to target audiences and measure campaign performance without relying on cookies. Additionally, FreeWheel and Roku will rely on data clean room technology to enable the activation of additional data sets providing better measurement and monetization to publishers and agencies.', metadata={'_additional': {'id': '962b79e0-f9d1-43ae-9f7a-8a9b42bc7a9a'}, 'chunk_type': 'text', 'chunk_years_mentioned': [], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': 'PressRelease', 'data_source_link': 'https://www.nasdaq.com/press-release/roku-and-freewheel-announce-strategic-partnership-to-bring-rokus-leading-ad-tech-to', 'data_source_publish_date': '2023-07-12T00:00:00Z', 'data_source_uid': 'a46f309c-705d-3946-96db-87aa4e73261f', 'title': 'ROKU INC | Roku and FreeWheel Announce Strategic Partnership to Bring Rokus Leading Ad Tech to FreeWheel Customers'}),\n",
" Document(page_content='Company Name: ROKU INC \\n Company Industry: CABLE & OTHER PAY TELEVISION SERVICES \\n Form Title: 10-K 2022-FY \\n Form Section: Risk Factors \\n Text: nd the Note Regarding Forward Looking Statements.This section of this Annual Report generally discusses fiscal years 2022 and 2021 and year to year comparisons between those years.Discussions of fiscal year 2020 and year to year comparisons between fiscal years 2021 and 2020 that are not included in this Annual Report can be found in Management\\'s Discussion and Analysis of Financial Condition and Results of Operations in Part II, Item 7 of our Annual Report for the fiscal year ended December 31, 2021 filed with the SEC on February 18, 2022.Overview Effective as of the fourth quarter of fiscal 2022, we reorganized our reportable segments to better align with management\\'s reporting of information reviewed by the Chief Operating Decision Maker (\"CODM\") for each segment.We renamed our \"player\" segment to \"devices\" which now includes our licensing arrangements with service operators and licensed Roku TV partners in addition to sales of our streaming players, audio products, smart home products and Roku branded TVs that will be designed, made, and sold by us in 2023.Our historical segment information is recast to conform to our new presentation in our financial statements and accompanying notes included in Item 8 of this Annual Report.Our two reportable segments are the platform segment and the devices segment.', metadata={'_additional': {'id': 'a76c5fed-5d63-45a7-b63a-2c30e05140fc'}, 'chunk_type': 'text', 'chunk_years_mentioned': [2020, 2021, 2022, 2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': '10-K', 'data_source_link': 'https://www.sec.gov/Archives/edgar/data/1428439/000142843923000007', 'data_source_publish_date': '2022-01-01T00:00:00Z', 'data_source_uid': '0001428439-23-000007', 'title': 'ROKU INC | 10-K 2022-FY '}),\n",
" Document(page_content='Company Name: ROKU INC \\n Company Industry: CABLE & OTHER PAY TELEVISION SERVICES \\n Form Title: 10-Q 2023-Q1 \\n Form Section: Risk Factors \\n Text: Our current and potential partners include TV brands, cable and satellite companies, and telecommunication providers.Under these license arrangements, we generally have limited or no control over the amount and timing of resources these entities dedicate to the relationship.In the past, our licensed Roku TV partners have failed to meet their forecasts and anticipated market launch dates for distributing Roku TV models, and they may fail to meet their forecasts or such launches in the future.If our licensed Roku TV partners or service operator partners fail to meet their forecasts or such launches for distributing licensed streaming devices or choose to deploy competing streaming solutions within their product lines, our business may be harmed.We depend on a small number of content publishers for a majority of our streaming hours, and if we fail to maintain these relationships, our business could be harmed.*Historically, a small number of content publishers have accounted for a significant portion of the hours streamed on our platform.In the three months ended March 31, 2023, the top three streaming services represented over 50% of all hours streamed in the period.If, for any reason, we cease distributing channels that have historically streamed a large percentage of the aggregate streaming hours on our platform, our streaming hours, our active accounts, or Roku streaming device sales may be adversely affected, and our business may be harmed.', metadata={'_additional': {'id': '2a92b2bb-02a0-4e15-8b64-d7e04078a205'}, 'chunk_type': 'text', 'chunk_years_mentioned': [2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': '10-Q', 'data_source_link': 'https://www.sec.gov/Archives/edgar/data/1428439/000142843923000017', 'data_source_publish_date': '2023-01-01T00:00:00Z', 'data_source_uid': '0001428439-23-000017', 'title': 'ROKU INC | 10-Q 2023-Q1 '})]"
"[Document(page_content='Company Name: ROKU INC\\nCompany Industry: CABLE & OTHER PAY TELEVISION SERVICES\\nArticle Title: Roku Is One of Fast Company\\'s Most Innovative Companies for 2023\\nText: The company launched several new devices, including the Roku Voice Remote Pro; upgraded its most premium player, the Roku Ultra; and expanded its products with a new line of smart home devices such as video doorbells, lights, and plugs integrated into the Roku ecosystem. Recently, the company announced it will launch Roku-branded TVs this spring to offer more choice and innovation to both consumers and Roku TV partners. Throughout 2022, Roku also updated its operating system (OS), the only OS purpose-built for TV, with more personalization features and enhancements across search, audio, and content discovery, launching The Buzz, Sports, and What to Watch, which provides tailored movie and TV recommendations on the Home Screen Menu. The company also released a new feature for streamers, Photo Streams, that allows customers to display and share photo albums through Roku streaming devices. Additionally, Roku unveiled Shoppable Ads, a new ad innovation that makes shopping on TV streaming as easy as it is on social media. Viewers simply press \"OK\" with their Roku remote on a shoppable ad and proceed to check out with their shipping and payment details pre-populated from Roku Pay, its proprietary payments platform. Walmart was the exclusive retailer for the launch, a first-of-its-kind partnership.', metadata={'chunk_type': 'text', 'chunk_years_mentioned': [2022, 2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': 'PressRelease', 'data_source_link': 'https://newsroom.roku.com/press-releases', 'data_source_publish_date': '2023-03-02T09:30:00-04:00', 'data_source_uid': '963d4a81-f58e-3093-af68-987fb1758c15', 'title': \"ROKU INC | Roku Is One of Fast Company's Most Innovative Companies for 2023\"}),\n",
" Document(page_content='Company Name: ROKU INC\\nCompany Industry: CABLE & OTHER PAY TELEVISION SERVICES\\nArticle Title: Roku Is One of Fast Company\\'s Most Innovative Companies for 2023\\nText: Finally, Roku grew its content offering with thousands of apps and watching options for users, including content on The Roku Channel, a top five app by reach and engagement on the Roku platform in the U.S. in 2022. In November, Roku released its first feature film, \"WEIRD: The Weird Al\\' Yankovic Story,\" a biopic starring Daniel Radcliffe. Throughout the year, The Roku Channel added FAST channels from NBCUniversal and the National Hockey League, as well as an exclusive AMC channel featuring its signature drama \"Mad Men.\" This year, the company announced a deal with Warner Bros. Discovery, launching new channels that will include \"Westworld\" and \"The Bachelor,\" in addition to 2,000 hours of on-demand content. Read more about Roku\\'s journey here . Fast Company\\'s Most Innovative Companies issue (March/April 2023) is available online here , as well as in-app via iTunes and on newsstands beginning March 14. About Roku, Inc.\\nRoku pioneered streaming to the TV. We connect users to the streaming content they love, enable content publishers to build and monetize large audiences, and provide advertisers with unique capabilities to engage consumers. Roku streaming players and TV-related audio devices are available in the U.S. and in select countries through direct retail sales and licensing arrangements with service operators. Roku TV models are available in the U.S. and select countries through licensing arrangements with TV OEM brands.', metadata={'chunk_type': 'text', 'chunk_years_mentioned': [2022, 2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': 'PressRelease', 'data_source_link': 'https://newsroom.roku.com/press-releases', 'data_source_publish_date': '2023-03-02T09:30:00-04:00', 'data_source_uid': '963d4a81-f58e-3093-af68-987fb1758c15', 'title': \"ROKU INC | Roku Is One of Fast Company's Most Innovative Companies for 2023\"}),\n",
" Document(page_content='Company Name: ROKU INC\\nCompany Industry: CABLE & OTHER PAY TELEVISION SERVICES\\nArticle Title: Roku\\'s New NFL Zone Gives Fans Easy Access to NFL Games Right On Time for 2023 Season\\nText: In partnership with the NFL, the new NFL Zone offers viewers an easy way to find where to watch NFL live games Today, Roku (NASDAQ: ROKU ) and the National Football League (NFL) announced the recently launched NFL Zone within the Roku Sports experience to kick off the 2023 NFL season. This strategic partnership between Roku and the NFL marks the first official league-branded zone within Roku\\'s Sports experience. Available now, the NFL Zone offers football fans a centralized location to find live and upcoming games, so they can spend less time figuring out where to watch the game and more time rooting for their favorite teams. Users can also tune in for weekly game previews, League highlights, and additional NFL content, all within the zone. This press release features multimedia. View the full release here: In partnership with the NFL, Roku\\'s new NFL Zone offers viewers an easy way to find where to watch NFL live games (Photo: Business Wire) \"Last year we introduced the Sports experience for our highly engaged sports audience, making it simpler for Roku users to watch sports programming,\" said Gidon Katz, President, Consumer Experience, at Roku. \"As we start the biggest sports season of the year, providing easy access to NFL games and content to our millions of users is a top priority for us. We look forward to fans immersing themselves within the NFL Zone and making it their destination to find NFL games.', metadata={'chunk_type': 'text', 'chunk_years_mentioned': [2023], 'company_name': 'ROKU INC', 'company_sic_code_description': 'CABLE & OTHER PAY TELEVISION SERVICES', 'data_source': 'PressRelease', 'data_source_link': 'https://newsroom.roku.com/press-releases', 'data_source_publish_date': '2023-09-12T09:00:00-04:00', 'data_source_uid': '963d4a81-f58e-3093-af68-987fb1758c15', 'title': \"ROKU INC | Roku's New NFL Zone Gives Fans Easy Access to NFL Games Right On Time for 2023 Season\"})]"
]
},
"execution_count": 21,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}

View File

@@ -28,19 +28,29 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 11,
"id": "63a8af5b",
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: You are using pip version 22.0.4; however, version 23.3 is available.\n",
"You should consider upgrading via the '/Users/joe/projects/elastic/langchain/libs/langchain/.venv/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"#!pip install lark elasticsearch"
"#!pip install -qU lark elasticsearch"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "cb4a5787",
"metadata": {
"tags": []
@@ -60,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 2,
"id": "bcbe04d9",
"metadata": {
"tags": []
@@ -115,7 +125,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"id": "86e34dbf",
"metadata": {
"tags": []
@@ -164,17 +174,10 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 4,
"id": "38a126e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='dinosaur' filter=None limit=None\n"
]
},
{
"data": {
"text/plain": [
@@ -184,7 +187,7 @@
" Document(page_content='A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea', metadata={'year': 2006, 'director': 'Satoshi Kon', 'rating': 8.6})]"
]
},
"execution_count": 10,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -196,24 +199,17 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 5,
"id": "b19d4da0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig') limit=None\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='A bunch of normal-sized women are supremely wholesome and some men pine after them', metadata={'year': 2019, 'director': 'Greta Gerwig', 'rating': 8.3})]"
]
},
"execution_count": 11,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -237,7 +233,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 6,
"id": "bff36b88-b506-4877-9c63-e5a1a8d78e64",
"metadata": {
"tags": []
@@ -256,19 +252,12 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 7,
"id": "2758d229-4f97-499c-819f-888acaf8ee10",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='dinosaur' filter=None limit=2\n"
]
},
{
"data": {
"text/plain": [
@@ -276,7 +265,7 @@
" Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]"
]
},
"execution_count": 13,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -297,24 +286,17 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 8,
"id": "e460da93",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='animated toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Operation(operator=<Operator.OR: 'or'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated'), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='comedy')]), Comparison(comparator=<Comparator.GTE: 'gte'>, attribute='year', value=1990)]) limit=None\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]"
]
},
"execution_count": 18,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -325,21 +307,10 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "0851fc42",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ObjectApiResponse({'acknowledged': True})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"vectorstore.client.indices.delete(index=\"elasticsearch-self-query-demo\")"
]
@@ -361,7 +332,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.3"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,120 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ab66dd43",
"metadata": {},
"source": [
"# SingleStoreDB\n",
"\n",
">[SingleStoreDB](https://singlestore.com/) is a high-performance distributed SQL database that supports deployment both in the [cloud](https://www.singlestore.com/cloud/) and on-premises. It provides vector storage, and vector functions including [dot_product](https://docs.singlestore.com/managed-service/en/reference/sql-reference/vector-functions/dot_product.html) and [euclidean_distance](https://docs.singlestore.com/managed-service/en/reference/sql-reference/vector-functions/euclidean_distance.html), thereby supporting AI applications that require text similarity matching. \n",
"\n",
"\n",
"This notebook shows how to use a retriever that uses `SingleStoreDB`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51b49135-a61a-49e8-869d-7c1d76794cd7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Establishing a connection to the database is facilitated through the singlestoredb Python connector.\n",
"# Please ensure that this connector is installed in your working environment.\n",
"!pip install singlestoredb"
]
},
{
"cell_type": "markdown",
"id": "aaf80e7f",
"metadata": {},
"source": [
"## Create Retriever from vector store"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcb3c8c2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"\n",
"# We want to use OpenAIEmbeddings so we have to get the OpenAI API Key.\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
"\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import SingleStoreDB\n",
"from langchain.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# Setup connection url as environment variable\n",
"os.environ[\"SINGLESTOREDB_URL\"] = \"root:pass@localhost:3306/db\"\n",
"\n",
"# Load documents to the store\n",
"docsearch = SingleStoreDB.from_documents(\n",
" docs,\n",
" embeddings,\n",
" table_name=\"notebook\", # use table with a custom name\n",
")\n",
"\n",
"# create retriever from the vector store\n",
"retriever = docsearch.as_retriever(search_kwargs={\"k\": 2})"
]
},
{
"cell_type": "markdown",
"id": "fc0915db",
"metadata": {},
"source": [
"## Search with retriever"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b605284d",
"metadata": {},
"outputs": [],
"source": [
"result = retriever.get_relevant_documents(\"What did the president say about Ketanji Brown Jackson\")\n",
"print(docs[0].page_content)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
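
As a possible follow-up to the notebook above (a sketch, not part of the committed file), the retriever can be dropped into a question-answering chain; an OPENAI_API_KEY and the `retriever` built above are assumed.

```python
# A hedged sketch: reuse the SingleStoreDB retriever in a RetrievalQA chain.
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI

qa = RetrievalQA.from_chain_type(
    llm=OpenAI(),
    chain_type="stuff",  # stuff the k=2 retrieved chunks into one prompt
    retriever=retriever,
)
qa.run("What did the president say about Ketanji Brown Jackson?")
```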

View File

@@ -139,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"id": "67ab8afa-f7c6-4fbf-b596-cb512da949da",
"metadata": {
"id": "67ab8afa-f7c6-4fbf-b596-cb512da949da",
@@ -172,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "aac9563e",
"metadata": {
"id": "aac9563e",
@@ -186,7 +186,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "a3c3999a",
"metadata": {
"id": "a3c3999a",
@@ -207,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"id": "12eb86d8",
"metadata": {
"id": "12eb86d8",
@@ -218,7 +218,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../modules/state_of_the_union.txt'}), Document(page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../modules/state_of_the_union.txt', 'date': '2016-01-01', 'rating': 2, 'author': 'John Doe'}), Document(page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../modules/state_of_the_union.txt', 'date': '2010-01-01', 'rating': 1, 'author': 'John Doe'}), Document(page_content='As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. \\n\\nWhile it often appears that we never agree, that isnt true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice.', metadata={'source': '../../modules/state_of_the_union.txt'})]\n"
"[Document(page_content='One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../modules/state_of_the_union.txt'}), Document(page_content='As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. \\n\\nWhile it often appears that we never agree, that isnt true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice.', metadata={'source': '../../modules/state_of_the_union.txt'}), Document(page_content='A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since shes been nominated, shes received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \\n\\nAnd if we are to advance liberty and justice, we need to secure the Border and fix the immigration system.', metadata={'source': '../../modules/state_of_the_union.txt'}), Document(page_content='This is personal to me and Jill, to Kamala, and to so many of you. \\n\\nCancer is the #2 cause of death in Americasecond only to heart disease. \\n\\nLast month, I announced our plan to supercharge \\nthe Cancer Moonshot that President Obama asked me to lead six years ago. \\n\\nOur goal is to cut the cancer death rate by at least 50% over the next 25 years, turn more cancers from death sentences into treatable diseases. \\n\\nMore support for patients and families.', metadata={'source': '../../modules/state_of_the_union.txt'})]\n"
]
}
],
@@ -247,7 +247,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"id": "5d076412",
"metadata": {},
"outputs": [
@@ -284,12 +284,13 @@
"## Filtering Metadata\n",
"With metadata added to the documents, you can add metadata filtering at query time. \n",
"\n",
"### Example: Filter by keyword"
"### Example: Filter by Exact keyword\n",
"Notice: We are using the keyword subfield thats not analyzed"
]
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 6,
"id": "b2a4bd1b",
"metadata": {},
"outputs": [
@@ -297,12 +298,42 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'source': '../../modules/state_of_the_union.txt', 'date': '2010-01-01', 'rating': 1, 'author': 'John Doe', 'geo_location': {'lat': 40.12, 'lon': -71.34}}\n"
"{'source': '../../modules/state_of_the_union.txt', 'date': '2016-01-01', 'rating': 2, 'author': 'John Doe'}\n"
]
}
],
"source": [
"docs = db.similarity_search(query, filter=[{ \"match\": { \"metadata.author\": \"John Doe\"}}])\n",
"docs = db.similarity_search(query, filter=[{ \"term\": { \"metadata.author.keyword\": \"John Doe\"}}])\n",
"print(docs[0].metadata)"
]
},
{
"cell_type": "markdown",
"id": "1898ab77",
"metadata": {},
"source": [
"### Example: Filter by Partial Match\n",
"This example shows how to filter by partial match. This is useful when you don't know the exact value of the metadata field. For example, if you want to filter by the metadata field `author` and you don't know the exact value of the author, you can use a partial match to filter by the author's last name. Fuzzy matching is also supported.\n",
"\n",
"\"Jon\" matches on \"John Doe\" as \"Jon\" is a close match to \"John\" token."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f3d294ff",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'source': '../../modules/state_of_the_union.txt', 'date': '2016-01-01', 'rating': 2, 'author': 'John Doe'}\n"
]
}
],
"source": [
"docs = db.similarity_search(query, filter=[{ \"match\": { \"metadata.author\": { \"query\": \"Jon\", \"fuzziness\": \"AUTO\" } }}])\n",
"print(docs[0].metadata)"
]
},
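
Since `filter` accepts a list, the exact and fuzzy styles above can be combined in one query; a sketch, assuming the `db` store and `query` string from earlier cells, and that `ElasticsearchStore` applies the list entries together as filter clauses.

```python
# A hedged sketch combining the two filter styles shown above.
docs = db.similarity_search(
    query,
    filter=[
        {"term": {"metadata.author.keyword": "John Doe"}},  # exact, not analyzed
        {"match": {"metadata.author": {"query": "Jon", "fuzziness": "AUTO"}}},  # fuzzy
    ],
)
print(docs[0].metadata)
```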

View File

@@ -4,9 +4,8 @@ This output parser can be used when you want to return a list of comma-separated
```python
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
output_parser = CommaSeparatedListOutputParser()
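
# A hedged sketch of how this example typically continues (not verbatim from
# the page): wire the parser's format instructions into a prompt, call the
# model, and parse the comma-separated reply into a Python list.
format_instructions = output_parser.get_format_instructions()
prompt = PromptTemplate(
    template="List five {subject}.\n{format_instructions}",
    input_variables=["subject"],
    partial_variables={"format_instructions": format_instructions},
)
llm = OpenAI(temperature=0)
output = llm(prompt.format(subject="ice cream flavors"))
output_parser.parse(output)  # -> ["Vanilla", "Chocolate", ...]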

Binary file not shown.

Before: Size 150 KiB | After: Size 185 KiB

View File

@@ -3848,8 +3848,8 @@
"destination": "/docs/additional_resources/dependents"
},
{
"source": "docs/integrations/retrievers/google_cloud_enterprise_search",
"destination": "docs/integrations/retrievers/google_vertex_ai_search"
"source": "/docs/integrations/retrievers/google_cloud_enterprise_search",
"destination": "/docs/integrations/retrievers/google_vertex_ai_search"
}
]
}

View File

@@ -1,12 +1,14 @@
import importlib.metadata
import logging
import os
import traceback
import warnings
from contextvars import ContextVar
from datetime import datetime
from typing import Any, Dict, List, Literal, Union
from uuid import UUID
import requests
from packaging.version import parse
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.agent import AgentAction, AgentFinish
@@ -142,7 +144,7 @@ def _get_user_props(metadata: Any) -> Any:
return user_props_ctx.get()
metadata = metadata or {}
return metadata.get("user_props")
return metadata.get("user_props", None)
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
@@ -191,6 +193,8 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
__api_url: str
__app_id: str
__verbose: bool
__llmonitor_version: str
__has_valid_config: bool
def __init__(
self,
@@ -200,37 +204,58 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
) -> None:
super().__init__()
self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
self.__has_valid_config = True
try:
import llmonitor
self.__llmonitor_version = importlib.metadata.version("llmonitor")
self.__track_event = llmonitor.track_event
except ImportError:
warnings.warn(
"""[LLMonitor] To use the LLMonitor callback handler you need to
have the `llmonitor` Python package installed. Please install it
with `pip install llmonitor`"""
)
self.__has_valid_config = False
if parse(self.__llmonitor_version) < parse("0.0.20"):
warnings.warn(
f"""[LLMonitor] The installed `llmonitor` version is
{self.__llmonitor_version}, but `LLMonitorCallbackHandler` requires
at least version 0.0.20. Upgrade `llmonitor` with `pip install
--upgrade llmonitor`."""
)
self.__has_valid_config = False
self.__has_valid_config = True
self.__api_url = api_url or os.getenv("LLMONITOR_API_URL") or DEFAULT_API_URL
self.__verbose = verbose or bool(os.getenv("LLMONITOR_VERBOSE"))
_app_id = app_id or os.getenv("LLMONITOR_APP_ID")
if _app_id is None:
raise ValueError(
"""app_id must be provided either as an argument or as
warnings.warn(
"""[LLMonitor] app_id must be provided either as an argument or as
an environment variable"""
)
self.__app_id = _app_id
self.__has_valid_config = False
else:
self.__app_id = _app_id
if self.__has_valid_config is False:
return None
try:
res = requests.get(f"{self.__api_url}/api/app/{self.__app_id}")
if not res.ok:
raise ConnectionError()
except Exception as e:
raise ConnectionError(
f"Could not connect to the LLMonitor API at {self.__api_url}"
) from e
def __send_event(self, event: Dict[str, Any]) -> None:
headers = {"Content-Type": "application/json"}
event = {**event, "app": self.__app_id, "timestamp": str(datetime.utcnow())}
if self.__verbose:
print("llmonitor_callback", event)
data = {"events": event}
requests.post(headers=headers, url=f"{self.__api_url}/api/report", json=data)
except Exception:
warnings.warn(
f"""[LLMonitor] Could not connect to the LLMonitor API at
{self.__api_url}"""
)
def on_llm_start(
self,
@@ -243,27 +268,28 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
name = kwargs.get("invocation_params", {}).get("model_name")
input = _parse_input(prompts)
event = {
"event": "start",
"type": "llm",
"userId": user_id,
"runId": str(run_id),
"parentRunId": str(parent_run_id) if parent_run_id else None,
"input": _parse_input(prompts),
"name": kwargs.get("invocation_params", {}).get("model_name"),
"tags": tags,
"metadata": metadata,
}
if user_props:
event["userProps"] = user_props
self.__send_event(event)
self.__track_event(
"llm",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
metadata=metadata,
user_props=user_props,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_llm_start: {e}")
warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}")
def on_chat_model_start(
self,
@@ -276,28 +302,29 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
name = kwargs.get("invocation_params", {}).get("model_name")
input = _parse_lc_messages(messages[0])
event = {
"event": "start",
"type": "llm",
"userId": user_id,
"runId": str(run_id),
"parentRunId": str(parent_run_id) if parent_run_id else None,
"input": _parse_lc_messages(messages[0]),
"name": kwargs.get("invocation_params", {}).get("model_name"),
"tags": tags,
"metadata": metadata,
}
if user_props:
event["userProps"] = user_props
self.__send_event(event)
self.__track_event(
"llm",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
metadata=metadata,
user_props=user_props,
)
except Exception as e:
logging.warning(
f"[LLMonitor] An error occurred in on_chat_model_start: " f"{e}"
f"[LLMonitor] An error occurred in on_chat_model_start: {e}"
)
def on_llm_end(
@@ -308,9 +335,11 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
token_usage = (response.llm_output or {}).get("token_usage", {})
parsed_output = [
{
"text": generation.text,
@@ -330,20 +359,19 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
for generation in response.generations[0]
]
event = {
"event": "end",
"type": "llm",
"runId": str(run_id),
"parent_run_id": str(parent_run_id) if parent_run_id else None,
"output": parsed_output,
"tokensUsage": {
self.__track_event(
"llm",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=parsed_output,
token_usage={
"prompt": token_usage.get("prompt_tokens"),
"completion": token_usage.get("completion_tokens"),
},
}
self.__send_event(event)
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_llm_end: {e}")
warnings.warn(f"[LLMonitor] An error occurred in on_llm_end: {e}")
def on_tool_start(
self,
@@ -356,27 +384,27 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
name = serialized.get("name")
event = {
"event": "start",
"type": "tool",
"userId": user_id,
"runId": str(run_id),
"parentRunId": str(parent_run_id) if parent_run_id else None,
"name": serialized.get("name"),
"input": input_str,
"tags": tags,
"metadata": metadata,
}
if user_props:
event["userProps"] = user_props
self.__send_event(event)
self.__track_event(
"tool",
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input_str,
tags=tags,
metadata=metadata,
user_props=user_props,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_tool_start: {e}")
warnings.warn(f"[LLMonitor] An error occurred in on_tool_start: {e}")
def on_tool_end(
self,
@@ -387,17 +415,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
tags: Union[List[str], None] = None,
**kwargs: Any,
) -> None:
if self.__has_valid_config is False:
return
try:
event = {
"event": "end",
"type": "tool",
"runId": str(run_id),
"parent_run_id": str(parent_run_id) if parent_run_id else None,
"output": output,
}
self.__send_event(event)
self.__track_event(
"tool",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_tool_end: {e}")
warnings.warn(f"[LLMonitor] An error occurred in on_tool_end: {e}")
def on_chain_start(
self,
@@ -410,6 +439,8 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
name = serialized.get("id", [None, None, None, None])[3]
type = "chain"
@@ -419,35 +450,32 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
if agentName is None:
agentName = metadata.get("agentName")
if name == "AgentExecutor" or name == "PlanAndExecute":
type = "agent"
if agentName is not None:
type = "agent"
name = agentName
if name == "AgentExecutor" or name == "PlanAndExecute":
type = "agent"
if parent_run_id is not None:
type = "chain"
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
input = _parse_input(inputs)
event = {
"event": "start",
"type": type,
"userId": user_id,
"runId": str(run_id),
"parentRunId": str(parent_run_id) if parent_run_id else None,
"input": _parse_input(inputs),
"tags": tags,
"metadata": metadata,
"name": name,
}
if user_props:
event["userProps"] = user_props
self.__send_event(event)
self.__track_event(
type,
"start",
user_id=user_id,
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
tags=tags,
metadata=metadata,
user_props=user_props,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_chain_start: {e}")
warnings.warn(f"[LLMonitor] An error occurred in on_chain_start: {e}")
def on_chain_end(
self,
@@ -457,14 +485,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
event = {
"event": "end",
"type": "chain",
"runId": str(run_id),
"output": _parse_output(outputs),
}
self.__send_event(event)
output = _parse_output(outputs)
self.__track_event(
"chain",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_chain_end: {e}")
@@ -476,16 +508,20 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
event = {
"event": "start",
"type": "tool",
"runId": str(run_id),
"parentRunId": str(parent_run_id) if parent_run_id else None,
"name": action.tool,
"input": _parse_input(action.tool_input),
}
self.__send_event(event)
name = action.tool
input = _parse_input(action.tool_input)
self.__track_event(
"tool",
"start",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_agent_action: {e}")
@@ -497,15 +533,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
event = {
"event": "end",
"type": "agent",
"runId": str(run_id),
"parentRunId": str(parent_run_id) if parent_run_id else None,
"output": _parse_output(finish.return_values),
}
self.__send_event(event)
output = _parse_output(finish.return_values)
self.__track_event(
"agent",
"end",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
@@ -517,15 +556,16 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
event = {
"event": "error",
"type": "chain",
"runId": str(run_id),
"parent_run_id": str(parent_run_id) if parent_run_id else None,
"error": {"message": str(error), "stack": traceback.format_exc()},
}
self.__send_event(event)
self.__track_event(
"chain",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_chain_error: {e}")
@@ -537,15 +577,16 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
event = {
"event": "error",
"type": "tool",
"runId": str(run_id),
"parent_run_id": str(parent_run_id) if parent_run_id else None,
"error": {"message": str(error), "stack": traceback.format_exc()},
}
self.__send_event(event)
self.__track_event(
"tool",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_tool_error: {e}")
@@ -557,15 +598,16 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id: Union[UUID, None] = None,
**kwargs: Any,
) -> Any:
if self.__has_valid_config is False:
return
try:
event = {
"event": "error",
"type": "llm",
"runId": str(run_id),
"parent_run_id": str(parent_run_id) if parent_run_id else None,
"error": {"message": str(error), "stack": traceback.format_exc()},
}
self.__send_event(event)
self.__track_event(
"llm",
"error",
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_llm_error: {e}")

View File

@@ -22,7 +22,7 @@ from langchain.utilities.openapi import OpenAPISpec
from langchain.utils.input import get_colored_text
if TYPE_CHECKING:
from openapi_schema_pydantic import Parameter
from openapi_pydantic import Parameter
def _get_description(o: Any, prefer_short: bool) -> Optional[str]:
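
For context, swapping the import to `openapi-pydantic` is what re-enables the OpenAPI helpers referenced in the PR description. A hedged sketch of the call path; the spec URL is illustrative and an OPENAI_API_KEY is assumed.

```python
# A hedged sketch of get_openapi_chain after the Pydantic v2 fix.
from langchain.chains.openai_functions.openapi import get_openapi_chain

chain = get_openapi_chain(
    "https://www.klarna.com/us/shopping/public/openai/v0/api-docs/"
)
chain("What are some options for a men's large blue button down shirt?")
```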

View File

@@ -38,6 +38,7 @@ from langchain.chat_models.minimax import MiniMaxChat
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.chat_models.vertexai import ChatVertexAI
from langchain.chat_models.yandex import ChatYandexGPT
@@ -63,6 +64,7 @@ __all__ = [
"ErnieBotChat",
"ChatJavelinAIGateway",
"ChatKonko",
"PaiEasChatEndpoint",
"QianfanChatEndpoint",
"ChatFireworks",
"ChatYandexGPT",

View File

@@ -11,7 +11,6 @@ from typing import (
List,
Optional,
Sequence,
Union,
cast,
)
@@ -38,12 +37,10 @@ from langchain.schema import (
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import (
AIMessage,
AnyMessage,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
@@ -79,7 +76,7 @@ async def _agenerate_from_stream(
return ChatResult(generations=[generation])
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for Chat models."""
cache: Optional[bool] = None
@@ -116,9 +113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
@property
def OutputType(self) -> Any:
"""Get the output type for this runnable."""
return Union[
HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
]
return AnyMessage
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
@@ -140,23 +135,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
) -> BaseMessage:
config = config or {}
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message,
)
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message
async def ainvoke(
self,
@@ -165,7 +157,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
) -> BaseMessage:
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
@@ -176,9 +168,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
run_name=config.get("run_name"),
**kwargs,
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
)
return cast(ChatGeneration, llm_result.generations[0][0]).message
def stream(
self,
@@ -190,7 +180,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
@@ -241,7 +233,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream == BaseChatModel._astream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
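
The net effect of this hunk: `invoke`/`ainvoke` on chat models are now typed to return `BaseMessage` rather than `BaseMessageChunk`, with the streaming fallback paths casting explicitly. A hedged sketch of the new contract (assumes an OPENAI_API_KEY):

```python
# A hedged sketch of the updated output typing.
from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import BaseMessage

chat = ChatOpenAI()
msg = chat.invoke("Say hi")          # typed as BaseMessage now
assert isinstance(msg, BaseMessage)  # an AIMessage in practice
for chunk in chat.stream("Say hi"):  # stream still yields BaseMessageChunk
    print(chunk.content, end="")
```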

View File

@@ -0,0 +1,324 @@
import asyncio
import json
import logging
from functools import partial
from typing import Any, AsyncIterator, Dict, List, Optional
import requests
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class PaiEasChatEndpoint(BaseChatModel):
"""Eas LLM Service chat model API.
To use, must have a deployed eas chat llm service on AliCloud. One can set the
environment variable ``eas_service_url`` and ``eas_service_token`` set with your eas
service url and service token.
Example:
.. code-block:: python
from langchain.chat_models import PaiEasChatEndpoint
eas_chat_endpoint = PaiEasChatEndpoint(
eas_service_url="your_service_url",
eas_service_token="your_service_token"
)
"""
"""PAI-EAS Service URL"""
eas_service_url: str
"""PAI-EAS Service TOKEN"""
eas_service_token: str
"""PAI-EAS Service Infer Params"""
max_new_tokens: Optional[int] = 512
temperature: Optional[float] = 0.8
top_p: Optional[float] = 0.1
top_k: Optional[int] = 10
do_sample: Optional[bool] = False
use_cache: Optional[bool] = True
stop_sequences: Optional[List[str]] = None
"""Enable stream chat mode."""
streaming: bool = False
"""Key/value arguments to pass to the model. Reserved for future use"""
model_kwargs: Optional[dict] = None
version: Optional[str] = "2.0"
timeout: Optional[int] = 5000
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["eas_service_url"] = get_from_dict_or_env(
values, "eas_service_url", "EAS_SERVICE_URL"
)
values["eas_service_token"] = get_from_dict_or_env(
values, "eas_service_token", "EAS_SERVICE_TOKEN"
)
return values
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
"eas_service_url": self.eas_service_url,
"eas_service_token": self.eas_service_token,
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "pai_eas_chat_endpoint"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
return {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"stop_sequences": [],
"do_sample": self.do_sample,
"use_cache": self.use_cache,
}
def _invocation_params(
self, stop_sequences: Optional[List[str]], **kwargs: Any
) -> dict:
params = self._default_params
if self.model_kwargs:
params.update(self.model_kwargs)
if self.stop_sequences is not None and stop_sequences is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop_sequences is not None:
params["stop"] = self.stop_sequences
else:
params["stop"] = stop_sequences
return {**params, **kwargs}
def format_request_payload(
self, messages: List[BaseMessage], **model_kwargs: Any
) -> dict:
prompt: Dict[str, Any] = {}
user_content: List[str] = []
assistant_content: List[str] = []
for message in messages:
"""Converts message to a dict according to role"""
if isinstance(message, HumanMessage):
user_content = user_content + [message.content]
elif isinstance(message, AIMessage):
assistant_content = assistant_content + [message.content]
elif isinstance(message, SystemMessage):
prompt["system_prompt"] = message.content
elif isinstance(message, ChatMessage) and message.role in [
"user",
"assistant",
"system",
]:
if message.role == "system":
prompt["system_prompt"] = message.content
elif message.role == "user":
user_content = user_content + [message.content]
elif message.role == "assistant":
assistant_content = assistant_content + [message.content]
else:
supported = ",".join([role for role in ["user", "assistant", "system"]])
raise ValueError(
f"""Received unsupported role.
Supported roles for the LLaMa Foundation Model: {supported}"""
)
prompt["prompt"] = user_content[len(user_content) - 1]
history = [
history_item
for _, history_item in enumerate(zip(user_content[:-1], assistant_content))
]
prompt["history"] = history
return {**prompt, **model_kwargs}
def _format_response_payload(
self, output: str, stop_sequences: Optional[List[str]]
) -> str:
"""Formats response"""
try:
text = json.loads(output)["response"]
if stop_sequences:
text = enforce_stop_tokens(text, stop_sequences)
return text
except Exception as e:
if isinstance(e, json.decoder.JSONDecodeError):
return output.decode("utf-8")
raise e
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
params = self._invocation_params(stop, **kwargs)
request_payload = self.format_request_payload(messages, **params)
response_payload = self._call_eas(request_payload)
generated_text = self._format_response_payload(response_payload, params["stop"])
if run_manager:
run_manager.on_llm_new_token(generated_text)
return generated_text
def _call_eas(self, query_body: dict) -> Any:
"""Generate text from the eas service."""
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"{self.eas_service_token}",
}
# make request
response = requests.post(
self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
)
if response.status_code != 200:
raise Exception(
f"Request failed with status code {response.status_code}"
f" and message {response.text}"
)
return response.text
def _call_eas_stream(self, query_body: dict) -> Any:
"""Generate text from the eas service."""
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"{self.eas_service_token}",
}
# make request
response = requests.post(
self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
)
if response.status_code != 200:
raise Exception(
f"Request failed with status code {response.status_code}"
f" and message {response.text}"
)
return response
def _convert_chunk_to_message(
self,
chunk: bytes,
) -> AIMessageChunk:
data = json.loads(chunk.decode("utf-8"))
return AIMessageChunk(content=data.get("response", ""))
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._invocation_params(stop, **kwargs)
request_payload = self.format_request_payload(messages, **params)
request_payload["use_stream_chat"] = True
response = self._call_eas_stream(request_payload)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
):
if chunk:
content = self._convert_chunk_to_message(chunk)
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in params["stop"]:
if stop_seq in content.content:
stop_seq_found = stop_seq
# truncate the content at the stop sequence, if found
if stop_seq_found:
content.content = content.content[
: content.content.index(stop_seq_found)
]
# yield the (possibly truncated) chunk, if any content remains
if content.content:
if run_manager:
await run_manager.on_llm_new_token(content.content)
yield ChatGenerationChunk(message=content)
# break if stop sequence found
if stop_seq_found:
break
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
generation = chunk
assert generation is not None
return ChatResult(generations=[generation])
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
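
A hedged invocation sketch for the new chat endpoint; the service URL and token are placeholders, mirroring the docstring.

```python
from langchain.chat_models import PaiEasChatEndpoint
from langchain.schema.messages import HumanMessage

chat = PaiEasChatEndpoint(
    eas_service_url="your_service_url",      # placeholder
    eas_service_token="your_service_token",  # placeholder
)
chat([HumanMessage(content="Say hello")])
```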

View File

@@ -342,6 +342,12 @@ def _import_openlm() -> Any:
return OpenLM
def _import_pai_eas_endpoint() -> Any:
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
return PaiEasEndpoint
def _import_petals() -> Any:
from langchain.llms.petals import Petals
@@ -593,6 +599,8 @@ def __getattr__(name: str) -> Any:
return _import_openllm()
elif name == "OpenLM":
return _import_openlm()
elif name == "PaiEasEndpoint":
return _import_pai_eas_endpoint()
elif name == "Petals":
return _import_petals()
elif name == "PipelineAI":
@@ -703,6 +711,7 @@ __all__ = [
"OpenAIChat",
"OpenLLM",
"OpenLM",
"PaiEasEndpoint",
"Petals",
"PipelineAI",
"Predibase",
@@ -780,6 +789,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"ollama": _import_ollama,
"openai": _import_openai,
"openlm": _import_openlm,
"pai_eas_endpoint": _import_pai_eas_endpoint,
"petals": _import_petals,
"pipelineai": _import_pipelineai,
"predibase": _import_predibase,

View File

@@ -6,9 +6,7 @@ from langchain.callbacks.manager import (
)
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.output import GenerationChunk
from langchain.schema.runnable.config import RunnableConfig
from langchain.utils.env import get_from_dict_or_env
@@ -140,42 +138,6 @@ class Fireworks(LLM):
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[str]:
prompt = self._convert_input(input).to_string()
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompt):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
prompt = self._convert_input(input).to_string()
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(prompt):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
def completion_with_retry(
llm: Fireworks,
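
With the custom `stream`/`astream` overrides removed, `Fireworks` falls back to the base `LLM` streaming implementation built on `_stream`/`_astream`. A hedged sketch; the `fireworks-ai` package and a FIREWORKS_API_KEY are assumed.

```python
# A hedged sketch of streaming via the inherited base-class implementation.
from langchain.llms.fireworks import Fireworks

llm = Fireworks()
for token in llm.stream("Tell me a joke"):
    print(token, end="")
```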

View File

@@ -91,7 +91,7 @@ class Modal(LLM):
if prompt in response.json()["prompt"]:
response_json = response.json()
except KeyError:
raise ValueError("LangChain requires 'prompt' key in response.")
raise KeyError("LangChain requires 'prompt' key in response.")
text = response_json["prompt"]
if stop is not None:
# I believe this is required since the stop tokens

View File

@@ -0,0 +1,240 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class PaiEasEndpoint(LLM):
"""Langchain LLM class to help to access eass llm service.
To use this endpoint, must have a deployed eas chat llm service on PAI AliCloud.
One can set the environment variable ``eas_service_url`` and ``eas_service_token``.
The environment variables can set with your eas service url and service token.
Example:
.. code-block:: python
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
eas_endpoint = PaiEasEndpoint(
eas_service_url="your_service_url",
eas_service_token="your_service_token"
)
"""
"""PAI-EAS Service URL"""
eas_service_url: str
"""PAI-EAS Service TOKEN"""
eas_service_token: str
"""PAI-EAS Service Infer Params"""
max_new_tokens: Optional[int] = 512
temperature: Optional[float] = 0.95
top_p: Optional[float] = 0.1
top_k: Optional[int] = 0
stop_sequences: Optional[List[str]] = None
"""Enable stream chat mode."""
streaming: bool = False
"""Key/value arguments to pass to the model. Reserved for future use"""
model_kwargs: Optional[dict] = None
version: Optional[str] = "2.0"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["eas_service_url"] = get_from_dict_or_env(
values, "eas_service_url", "EAS_SERVICE_URL"
)
values["eas_service_token"] = get_from_dict_or_env(
values, "eas_service_token", "EAS_SERVICE_TOKEN"
)
return values
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "pai_eas_endpoint"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
return {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"stop_sequences": [],
}
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
"eas_service_url": self.eas_service_url,
"eas_service_token": self.eas_service_token,
**_model_kwargs,
}
def _invocation_params(
self, stop_sequences: Optional[List[str]], **kwargs: Any
) -> dict:
params = self._default_params
if self.stop_sequences is not None and stop_sequences is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop_sequences is not None:
params["stop"] = self.stop_sequences
else:
params["stop"] = stop_sequences
if self.model_kwargs:
params.update(self.model_kwargs)
return {**params, **kwargs}
@staticmethod
def _process_response(
response: Any, stop: Optional[List[str]], version: Optional[str]
) -> str:
if version == "1.0":
text = response
else:
text = response["response"]
if stop:
text = enforce_stop_tokens(text, stop)
return "".join(text)
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
params = self._invocation_params(stop, **kwargs)
prompt = prompt.strip()
response = None
try:
if self.streaming:
completion = ""
for chunk in self._stream(prompt, stop, run_manager, **params):
completion += chunk.text
return completion
else:
response = self._call_eas(prompt, params)
_stop = params.get("stop")
return self._process_response(response, _stop, self.version)
except Exception as error:
raise ValueError(f"Error raised by the service: {error}")
def _call_eas(self, prompt: str = "", params: Dict = {}) -> Any:
"""Generate text from the eas service."""
headers = {
"Content-Type": "application/json",
"Authorization": f"{self.eas_service_token}",
}
if self.version == "1.0":
body = {
"input_ids": f"{prompt}",
}
else:
body = {
"prompt": f"{prompt}",
}
# add params to body
for key, value in params.items():
body[key] = value
# make request
response = requests.post(self.eas_service_url, headers=headers, json=body)
if response.status_code != 200:
raise Exception(
f"Request failed with status code {response.status_code}"
f" and message {response.text}"
)
try:
return json.loads(response.text)
except Exception as e:
if isinstance(e, json.decoder.JSONDecodeError):
return response.text
raise e
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
invocation_params = self._invocation_params(stop, **kwargs)
headers = {
"User-Agent": "Test Client",
"Authorization": f"{self.eas_service_token}",
}
if self.version == "1.0":
pload = {"input_ids": prompt, **invocation_params}
response = requests.post(
self.eas_service_url, headers=headers, json=pload, stream=True
)
res = GenerationChunk(text=response.text)
if run_manager:
run_manager.on_llm_new_token(res.text)
# yield text, if any
yield res
else:
pload = {"prompt": prompt, "use_stream_chat": "True", **invocation_params}
response = requests.post(
self.eas_service_url, headers=headers, json=pload, stream=True
)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["response"]
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop"]:
if stop_seq in output:
stop_seq_found = stop_seq
# identify text to yield
text: Optional[str] = None
if stop_seq_found:
text = output[: output.index(stop_seq_found)]
else:
text = output
# yield text, if any
if text:
res = GenerationChunk(text=text)
yield res
if run_manager:
run_manager.on_llm_new_token(res.text)
# break if stop sequence found
if stop_seq_found:
break
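
A hedged sketch of the new completion endpoint, including the streaming path above; the service URL and token are placeholders.

```python
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint

llm = PaiEasEndpoint(
    eas_service_url="your_service_url",      # placeholder
    eas_service_token="your_service_token",  # placeholder
)
print(llm("Say foo:"))
for chunk in llm.stream("Say foo:"):  # exercises the _stream path above
    print(chunk, end="")
```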

View File

@@ -1,6 +1,9 @@
from typing import Any, Dict, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.load.serializable import Serializable
@@ -128,3 +131,75 @@ class YandexGPT(_BaseYandexGPT, LLM):
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Async call the Yandex GPT model and return the output.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
"""
try:
import asyncio
import grpc
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import (
InstructRequest,
InstructResponse,
)
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
TextGenerationAsyncServiceStub,
)
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
from yandex.cloud.operation.operation_service_pb2_grpc import (
OperationServiceStub,
)
except ImportError as e:
raise ImportError(
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
) from e
operation_api_url = "operation.api.cloud.yandex.net:443"
channel_credentials = grpc.ssl_channel_credentials()
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
request = InstructRequest(
model=self.model_name,
request_text=prompt,
generation_options=GenerationOptions(
temperature=DoubleValue(value=self.temperature),
max_tokens=Int64Value(value=self.max_tokens),
),
)
stub = TextGenerationAsyncServiceStub(channel)
if self.iam_token:
metadata = (("authorization", f"Bearer {self.iam_token}"),)
else:
metadata = (("authorization", f"Api-Key {self.api_key}"),)
operation = await stub.Instruct(request, metadata=metadata)
async with grpc.aio.secure_channel(
operation_api_url, channel_credentials
) as operation_channel:
operation_stub = OperationServiceStub(operation_channel)
while not operation.done:
await asyncio.sleep(1)
operation_request = GetOperationRequest(operation_id=operation.id)
operation = await operation_stub.Get(
operation_request, metadata=metadata
)
instruct_response = InstructResponse()
operation.response.Unpack(instruct_response)
text = instruct_response.alternatives[0].text
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
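
A hedged async usage sketch for the new `_acall` path; Yandex credentials are assumed to be available to the model (the `iam_token`/`api_key` fields or their environment equivalents).

```python
import asyncio

from langchain.llms import YandexGPT

async def main() -> None:
    llm = YandexGPT()  # assumes credentials are configured
    result = await llm.agenerate(["What is the capital of France?"])
    print(result.generations[0][0].text)

asyncio.run(main())
```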

View File

@@ -17,9 +17,12 @@ class OutputFixingParser(BaseOutputParser[T]):
return True
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@classmethod
def from_llm(
@@ -35,7 +38,7 @@ class OutputFixingParser(BaseOutputParser[T]):
llm: llm to use for fixing
parser: parser to use for parsing
prompt: prompt to use for fixing
max_retries: Maximum number of retries to parser.
max_retries: Maximum number of retries to parse.
Returns:
OutputFixingParser

View File

@@ -48,6 +48,8 @@ class RetryOutputParser(BaseOutputParser[T]):
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@classmethod
def from_llm(
@@ -55,11 +57,23 @@ class RetryOutputParser(BaseOutputParser[T]):
llm: BaseLanguageModel,
parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
max_retries: int = 1,
) -> RetryOutputParser[T]:
"""Create an OutputFixingParser from a language model and a parser.
Args:
llm: llm to use for fixing
parser: parser to use for parsing
prompt: prompt to use for fixing
max_retries: Maximum number of retries to parse.
Returns:
RetryOutputParser
"""
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
"""Parse the output of an LLM call using a wrapped parser.
@@ -71,15 +85,21 @@ class RetryOutputParser(BaseOutputParser[T]):
Returns:
The parsed completion.
"""
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
parsed_completion = self.parser.parse(new_completion)
retries = 0
return parsed_completion
while retries <= self.max_retries:
try:
return self.parser.parse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
raise OutputParserException("Failed to parse")
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
"""Parse the output of an LLM call using a wrapped parser.
@@ -91,15 +111,21 @@ class RetryOutputParser(BaseOutputParser[T]):
Returns:
The parsed completion.
"""
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException:
new_completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(), completion=completion
)
parsed_completion = self.parser.parse(new_completion)
retries = 0
return parsed_completion
while retries <= self.max_retries:
try:
return await self.parser.aparse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(), completion=completion
)
raise OutputParserException("Failed to parse")
def parse(self, completion: str) -> T:
raise NotImplementedError(
@@ -125,8 +151,12 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
"""
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
@classmethod
def from_llm(
@@ -134,6 +164,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
llm: BaseLanguageModel,
parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
max_retries: int = 1,
) -> RetryWithErrorOutputParser[T]:
"""Create a RetryWithErrorOutputParser from an LLM.
@@ -141,6 +172,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
llm: The LLM to use to retry the completion.
parser: The parser to use to parse the output.
prompt: The prompt to use to retry the completion.
max_retries: The maximum number of times to retry the completion.
Returns:
A RetryWithErrorOutputParser.
@@ -148,29 +180,45 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain)
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
)
parsed_completion = self.parser.parse(new_completion)
retries = 0
return parsed_completion
while retries <= self.max_retries:
try:
return self.parser.parse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
raise OutputParserException("Failed to parse")
async def aparse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
try:
parsed_completion = self.parser.parse(completion)
except OutputParserException as e:
new_completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(), completion=completion, error=repr(e)
)
parsed_completion = self.parser.parse(new_completion)
retries = 0
return parsed_completion
while retries <= self.max_retries:
try:
return await self.parser.aparse(completion)
except OutputParserException as e:
if retries == self.max_retries:
raise e
else:
retries += 1
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
raise OutputParserException("Failed to parse")
def parse(self, completion: str) -> T:
raise NotImplementedError(
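
The new `max_retries` argument bounds how many times the retry chain re-asks the LLM before the parser gives up. A hedged sketch (assumes an OPENAI_API_KEY; the prompt and completion are illustrative):

```python
from langchain.llms import OpenAI
from langchain.output_parsers import CommaSeparatedListOutputParser, RetryOutputParser
from langchain.prompts import PromptTemplate

retry_parser = RetryOutputParser.from_llm(
    llm=OpenAI(temperature=0),
    parser=CommaSeparatedListOutputParser(),
    max_retries=3,  # re-ask the LLM up to 3 times before raising
)
prompt_value = PromptTemplate.from_template(
    "List three {subject}, comma separated."
).format_prompt(subject="fruits")
retry_parser.parse_with_prompt("apple, banana, cherry", prompt_value)
```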

View File

@@ -32,10 +32,8 @@ from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.retrievers.google_cloud_documentai_warehouse import (
GoogleDocumentAIWarehouseRetriever,
)
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import (
GoogleCloudEnterpriseSearchRetriever,
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)

View File

@@ -1,22 +0,0 @@
"""Retriever wrapper for Google Vertex AI Search.
DEPRECATED: Maintained for backwards compatibility.
"""
from typing import Any
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
"""`Google Vertex Search API` retriever alias for backwards compatibility.
DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
"""
def __init__(self, **data: Any):
import warnings
warnings.warn(
"GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
DeprecationWarning,
)
super().__init__(**data)
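
The alias keeps old imports working (now with a DeprecationWarning) while the implementation lives in `google_vertex_ai_search`. A hedged migration sketch; the project and data store ids are placeholders.

```python
# A hedged sketch; ids are placeholders and Google Cloud credentials are assumed.
from langchain.retrievers import GoogleVertexAISearchRetriever

retriever = GoogleVertexAISearchRetriever(
    project_id="my-project",
    data_store_id="my-data-store",
)
docs = retriever.get_relevant_documents("What is Vertex AI Search?")
```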

View File

@@ -25,10 +25,21 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
"""Vertex AI Search data store ID."""
location_id: str = "global"
"""Vertex AI Search data store location."""
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
credentials: Any = None
"""The default custom credentials (google.auth.credentials.Credentials) to use
when making API calls. If not provided, credentials will be ascertained from
the environment."""
engine_data_type: int = Field(default=0, ge=0, le=2)
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
2 - Website data (with Advanced Website Indexing)
"""
_serving_config: str
"""Full path of serving config."""
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
@@ -144,6 +155,47 @@ class _BaseGoogleVertexAISearchRetriever(BaseModel):
return documents
def _convert_website_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
derived_struct_data = document_dict.get("derived_struct_data")
if not derived_struct_data:
continue
doc_metadata = document_dict.get("struct_data", {})
doc_metadata["id"] = document_dict["id"]
doc_metadata["source"] = derived_struct_data.get("link", "")
chunk_type = "extractive_answers"
if chunk_type not in derived_struct_data:
continue
for chunk in derived_struct_data[chunk_type]:
documents.append(
Document(
page_content=chunk.get("content", ""), metadata=doc_metadata
)
)
if not documents:
print(
f"No {chunk_type} could be found.\n"
"Make sure that your data store is using Advanced Website Indexing.\n"
"https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing" # noqa: E501
)
return documents
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
"""`Google Vertex AI Search` retriever.
@@ -153,8 +205,6 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
"""
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
filter: Optional[str] = None
"""Filter expression."""
get_extractive_answers: bool = False
@@ -188,15 +238,7 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
Search will be based on the corrected query if found.
"""
# TODO: Add extra data type handling for type website
engine_data_type: int = Field(default=0, ge=0, le=1)
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
"""
_client: SearchServiceClient
_serving_config: str
class Config:
"""Configuration for this pydantic object."""
@@ -260,11 +302,16 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
)
elif self.engine_data_type == 1:
content_search_spec = None
elif self.engine_data_type == 2:
content_search_spec = SearchRequest.ContentSearchSpec(
extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_answer_count=self.max_extractive_answer_count,
)
)
else:
# TODO: Add extra data type handling for type website
raise NotImplementedError(
"Only engine data type 0 (Unstructured) or 1 (Structured)"
+ " are supported currently."
"Only data store type 0 (Unstructured), 1 (Structured),"
"or 2 (Website with Advanced Indexing) are supported currently."
+ f" Got {self.engine_data_type}"
)
@@ -305,11 +352,12 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
)
elif self.engine_data_type == 1:
documents = self._convert_structured_search_response(response.results)
elif self.engine_data_type == 2:
documents = self._convert_website_search_response(response.results)
else:
# TODO: Add extra data type handling for type website
raise NotImplementedError(
"Only engine data type 0 (Unstructured) or 1 (Structured)"
+ " are supported currently."
"Only data store type 0 (Unstructured), 1 (Structured),"
"or 2 (Website with Advanced Indexing) are supported currently."
+ f" Got {self.engine_data_type}"
)
@@ -321,6 +369,9 @@ class GoogleVertexAIMultiTurnSearchRetriever(
):
"""`Google Vertex AI Search` retriever for multi-turn conversations."""
conversation_id: str = "-"
"""Vertex AI Search Conversation ID."""
_client: ConversationalSearchServiceClient
class Config:
@@ -340,6 +391,20 @@ class GoogleVertexAIMultiTurnSearchRetriever(
credentials=self.credentials, client_options=self.client_options
)
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
if self.engine_data_type == 1:
raise NotImplementedError(
"Data store type 1 (Structured)"
"is not currently supported for multi-turn search."
+ f" Got {self.engine_data_type}"
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
@@ -351,11 +416,35 @@ class GoogleVertexAIMultiTurnSearchRetriever(
request = ConverseConversationRequest(
name=self._client.conversation_path(
self.project_id, self.location_id, self.data_store_id, "-"
self.project_id,
self.location_id,
self.data_store_id,
self.conversation_id,
),
serving_config=self._serving_config,
query=TextInput(input=query),
)
response = self._client.converse_conversation(request)
if self.engine_data_type == 2:
return self._convert_website_search_response(response.search_results)
return self._convert_unstructured_search_response(
response.search_results, "extractive_answers"
)
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
"""`Google Vertex Search API` retriever alias for backwards compatibility.
DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
"""
def __init__(self, **data: Any):
import warnings
warnings.warn(
"GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever", # noqa: E501
DeprecationWarning,
)
super().__init__(**data)

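A hedged usage sketch of the expanded data-store support and the relocated compatibility alias; the project and data store IDs are placeholders:

from langchain.retrievers import (
    GoogleVertexAIMultiTurnSearchRetriever,
    GoogleVertexAISearchRetriever,
)

# engine_data_type=2 is new: website data stores with Advanced Website Indexing.
retriever = GoogleVertexAISearchRetriever(
    project_id="my-project",           # placeholder
    data_store_id="my-website-store",  # placeholder
    engine_data_type=2,
)
docs = retriever.get_relevant_documents("What is Vertex AI Search?")

# Multi-turn search now accepts a conversation_id ("-" starts a new conversation)
# and rejects structured (type 1) data stores at validation time.
multi_turn = GoogleVertexAIMultiTurnSearchRetriever(
    project_id="my-project",
    data_store_id="my-unstructured-store",  # placeholder
)

# The old class name still imports but emits a DeprecationWarning (shim above).
from langchain.retrievers import GoogleCloudEnterpriseSearchRetriever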
View File

@@ -39,7 +39,7 @@ class ElasticsearchTranslator(Visitor):
Comparator.LT: "lt",
Comparator.LTE: "lte",
Comparator.CONTAIN: "match",
Comparator.LIKE: "fuzzy",
Comparator.LIKE: "match",
}
return map_dict[func]
@@ -67,15 +67,19 @@ class ElasticsearchTranslator(Visitor):
}
}
if comparison.comparator == Comparator.LIKE:
if comparison.comparator == Comparator.CONTAIN:
return {
self._format_func(comparison.comparator): {
field: {"value": comparison.value, "fuzziness": "AUTO"}
field: {"query": comparison.value}
}
}
if comparison.comparator == Comparator.CONTAIN:
return {self._format_func(comparison.comparator): {field: comparison.value}}
if comparison.comparator == Comparator.LIKE:
return {
self._format_func(comparison.comparator): {
field: {"query": comparison.value, "fuzziness": "AUTO"}
}
}
# we assume that if the value is a string,
# we want to use the keyword field

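The self-query unit tests further down mirror this change; a small sketch of what the translator now emits for LIKE and CONTAIN, using the same import paths as those tests:

from langchain.chains.query_constructor.ir import Comparator, Comparison
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator

translator = ElasticsearchTranslator()

# LIKE now compiles to an analyzed match query with fuzziness, rather than a
# term-level fuzzy query that bypasses the field's analyzer:
like = translator.visit_comparison(
    Comparison(comparator=Comparator.LIKE, attribute="foo", value="bar")
)
assert like == {"match": {"metadata.foo": {"query": "bar", "fuzziness": "AUTO"}}}

# CONTAIN keeps the match clause but moves the value under an explicit "query" key:
contain = translator.visit_comparison(
    Comparison(comparator=Comparator.CONTAIN, attribute="foo", value="1")
)
assert contain == {"match": {"metadata.foo": {"query": "1"}}}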
File diff suppressed because it is too large

View File

@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
import requests
import yaml
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, ValidationError
from langchain.pydantic_v1 import ValidationError
logger = logging.getLogger(__name__)
@@ -38,288 +38,278 @@ class HTTPVerb(str, Enum):
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
if _PYDANTIC_MAJOR_VERSION == 1:
if TYPE_CHECKING:
from openapi_schema_pydantic import (
Components,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
if TYPE_CHECKING:
from openapi_pydantic import (
Components,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
try:
from openapi_pydantic import OpenAPI
except ImportError:
OpenAPI = object # type: ignore
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes mis-formatted parts of the spec."""
openapi: str = "3.1.0" # overriding overly restrictive type from parent class
@property
def _paths_strict(self) -> Paths:
if not self.paths:
raise ValueError("No paths found in spec")
return self.paths
def _get_path_strict(self, path: str) -> PathItem:
path_item = self._paths_strict.get(path)
if not path_item:
raise ValueError(f"No path found for {path}")
return path_item
@property
def _components_strict(self) -> Components:
"""Get components or err."""
if self.components is None:
raise ValueError("No components found in spec. ")
return self.components
@property
def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
"""Get parameters or err."""
parameters = self._components_strict.parameters
if parameters is None:
raise ValueError("No parameters found in spec. ")
return parameters
@property
def _schemas_strict(self) -> Dict[str, Schema]:
"""Get the dictionary of schemas or err."""
schemas = self._components_strict.schemas
if schemas is None:
raise ValueError("No schemas found in spec. ")
return schemas
@property
def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
"""Get the request body or err."""
request_bodies = self._components_strict.requestBodies
if request_bodies is None:
raise ValueError("No request body found in spec. ")
return request_bodies
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
"""Get a parameter (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
parameters = self._parameters_strict
if ref_name not in parameters:
raise ValueError(f"No parameter found for {ref_name}")
return parameters[ref_name]
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
"""Get the root reference or err."""
from openapi_pydantic import Reference
parameter = self._get_referenced_parameter(ref)
while isinstance(parameter, Reference):
parameter = self._get_referenced_parameter(parameter)
return parameter
def get_referenced_schema(self, ref: Reference) -> Schema:
"""Get a schema (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
schemas = self._schemas_strict
if ref_name not in schemas:
raise ValueError(f"No schema found for {ref_name}")
return schemas[ref_name]
def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
from openapi_pydantic import Reference
if isinstance(schema, Reference):
return self.get_referenced_schema(schema)
return schema
def _get_root_referenced_schema(self, ref: Reference) -> Schema:
"""Get the root reference or err."""
from openapi_pydantic import Reference
schema = self.get_referenced_schema(ref)
while isinstance(schema, Reference):
schema = self.get_referenced_schema(schema)
return schema
def _get_referenced_request_body(
self, ref: Reference
) -> Optional[Union[Reference, RequestBody]]:
"""Get a request body (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
request_bodies = self._request_bodies_strict
if ref_name not in request_bodies:
raise ValueError(f"No request body found for {ref_name}")
return request_bodies[ref_name]
def _get_root_referenced_request_body(
self, ref: Reference
) -> Optional[RequestBody]:
"""Get the root request Body or err."""
from openapi_pydantic import Reference
request_body = self._get_referenced_request_body(ref)
while isinstance(request_body, Reference):
request_body = self._get_referenced_request_body(request_body)
return request_body
@staticmethod
def _alert_unsupported_spec(obj: dict) -> None:
"""Alert if the spec is not supported."""
warning_message = (
" This may result in degraded performance."
+ " Convert your OpenAPI spec to 3.1.* spec"
+ " for better support."
)
try:
from openapi_schema_pydantic import OpenAPI
except ImportError:
OpenAPI = object # type: ignore
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes mis-formatted parts of the spec."""
@property
def _paths_strict(self) -> Paths:
if not self.paths:
raise ValueError("No paths found in spec")
return self.paths
def _get_path_strict(self, path: str) -> PathItem:
path_item = self._paths_strict.get(path)
if not path_item:
raise ValueError(f"No path found for {path}")
return path_item
@property
def _components_strict(self) -> Components:
"""Get components or err."""
if self.components is None:
raise ValueError("No components found in spec. ")
return self.components
@property
def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
"""Get parameters or err."""
parameters = self._components_strict.parameters
if parameters is None:
raise ValueError("No parameters found in spec. ")
return parameters
@property
def _schemas_strict(self) -> Dict[str, Schema]:
"""Get the dictionary of schemas or err."""
schemas = self._components_strict.schemas
if schemas is None:
raise ValueError("No schemas found in spec. ")
return schemas
@property
def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
"""Get the request body or err."""
request_bodies = self._components_strict.requestBodies
if request_bodies is None:
raise ValueError("No request body found in spec. ")
return request_bodies
def _get_referenced_parameter(
self, ref: Reference
) -> Union[Parameter, Reference]:
"""Get a parameter (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
parameters = self._parameters_strict
if ref_name not in parameters:
raise ValueError(f"No parameter found for {ref_name}")
return parameters[ref_name]
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
"""Get the root reference or err."""
from openapi_schema_pydantic import Reference
parameter = self._get_referenced_parameter(ref)
while isinstance(parameter, Reference):
parameter = self._get_referenced_parameter(parameter)
return parameter
def get_referenced_schema(self, ref: Reference) -> Schema:
"""Get a schema (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
schemas = self._schemas_strict
if ref_name not in schemas:
raise ValueError(f"No schema found for {ref_name}")
return schemas[ref_name]
def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
from openapi_schema_pydantic import Reference
if isinstance(schema, Reference):
return self.get_referenced_schema(schema)
return schema
def _get_root_referenced_schema(self, ref: Reference) -> Schema:
"""Get the root reference or err."""
from openapi_schema_pydantic import Reference
schema = self.get_referenced_schema(ref)
while isinstance(schema, Reference):
schema = self.get_referenced_schema(schema)
return schema
def _get_referenced_request_body(
self, ref: Reference
) -> Optional[Union[Reference, RequestBody]]:
"""Get a request body (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
request_bodies = self._request_bodies_strict
if ref_name not in request_bodies:
raise ValueError(f"No request body found for {ref_name}")
return request_bodies[ref_name]
def _get_root_referenced_request_body(
self, ref: Reference
) -> Optional[RequestBody]:
"""Get the root request Body or err."""
from openapi_schema_pydantic import Reference
request_body = self._get_referenced_request_body(ref)
while isinstance(request_body, Reference):
request_body = self._get_referenced_request_body(request_body)
return request_body
@staticmethod
def _alert_unsupported_spec(obj: dict) -> None:
"""Alert if the spec is not supported."""
warning_message = (
" This may result in degraded performance."
+ " Convert your OpenAPI spec to 3.1.* spec"
+ " for better support."
)
swagger_version = obj.get("swagger")
openapi_version = obj.get("openapi")
if isinstance(openapi_version, str):
if openapi_version != "3.1.0":
logger.warning(
f"Attempting to load an OpenAPI {openapi_version}"
f" spec. {warning_message}"
)
else:
pass
elif isinstance(swagger_version, str):
swagger_version = obj.get("swagger")
openapi_version = obj.get("openapi")
if isinstance(openapi_version, str):
if openapi_version != "3.1.0":
logger.warning(
f"Attempting to load a Swagger {swagger_version}"
f"Attempting to load an OpenAPI {openapi_version}"
f" spec. {warning_message}"
)
else:
raise ValueError(
"Attempting to load an unsupported spec:"
f"\n\n{obj}\n{warning_message}"
)
pass
elif isinstance(swagger_version, str):
logger.warning(
f"Attempting to load a Swagger {swagger_version}"
f" spec. {warning_message}"
)
else:
raise ValueError(
"Attempting to load an unsupported spec:"
f"\n\n{obj}\n{warning_message}"
)
@classmethod
def parse_obj(cls, obj: dict) -> OpenAPISpec:
try:
cls._alert_unsupported_spec(obj)
return super().parse_obj(obj)
except ValidationError as e:
# We are handling possibly misconfigured specs and
# want to do a best-effort job to get a reasonable interface out of it.
new_obj = copy.deepcopy(obj)
for error in e.errors():
keys = error["loc"]
item = new_obj
for key in keys[:-1]:
item = item[key]
item.pop(keys[-1], None)
return cls.parse_obj(new_obj)
@classmethod
def parse_obj(cls, obj: dict) -> OpenAPISpec:
try:
cls._alert_unsupported_spec(obj)
return super().parse_obj(obj)
except ValidationError as e:
# We are handling possibly misconfigured specs and
# want to do a best-effort job to get a reasonable interface out of it.
new_obj = copy.deepcopy(obj)
for error in e.errors():
keys = error["loc"]
item = new_obj
for key in keys[:-1]:
item = item[key]
item.pop(keys[-1], None)
return cls.parse_obj(new_obj)
@classmethod
def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
"""Get an OpenAPI spec from a dict."""
return cls.parse_obj(spec_dict)
@classmethod
def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
"""Get an OpenAPI spec from a dict."""
return cls.parse_obj(spec_dict)
@classmethod
def from_text(cls, text: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a text."""
try:
spec_dict = json.loads(text)
except json.JSONDecodeError:
spec_dict = yaml.safe_load(text)
return cls.from_spec_dict(spec_dict)
@classmethod
def from_text(cls, text: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a text."""
try:
spec_dict = json.loads(text)
except json.JSONDecodeError:
spec_dict = yaml.safe_load(text)
return cls.from_spec_dict(spec_dict)
@classmethod
def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
"""Get an OpenAPI spec from a file path."""
path_ = path if isinstance(path, Path) else Path(path)
if not path_.exists():
raise FileNotFoundError(f"{path} does not exist")
with path_.open("r") as f:
return cls.from_text(f.read())
@classmethod
def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
"""Get an OpenAPI spec from a file path."""
path_ = path if isinstance(path, Path) else Path(path)
if not path_.exists():
raise FileNotFoundError(f"{path} does not exist")
with path_.open("r") as f:
return cls.from_text(f.read())
@classmethod
def from_url(cls, url: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a URL."""
response = requests.get(url)
return cls.from_text(response.text)
@classmethod
def from_url(cls, url: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a URL."""
response = requests.get(url)
return cls.from_text(response.text)
@property
def base_url(self) -> str:
"""Get the base url."""
return self.servers[0].url
@property
def base_url(self) -> str:
"""Get the base url."""
return self.servers[0].url
def get_methods_for_path(self, path: str) -> List[str]:
"""Return a list of valid methods for the specified path."""
from openapi_schema_pydantic import Operation
def get_methods_for_path(self, path: str) -> List[str]:
"""Return a list of valid methods for the specified path."""
from openapi_pydantic import Operation
path_item = self._get_path_strict(path)
results = []
for method in HTTPVerb:
operation = getattr(path_item, method.value, None)
if isinstance(operation, Operation):
results.append(method.value)
return results
path_item = self._get_path_strict(path)
results = []
for method in HTTPVerb:
operation = getattr(path_item, method.value, None)
if isinstance(operation, Operation):
results.append(method.value)
return results
def get_parameters_for_path(self, path: str) -> List[Parameter]:
from openapi_schema_pydantic import Reference
def get_parameters_for_path(self, path: str) -> List[Parameter]:
from openapi_pydantic import Reference
path_item = self._get_path_strict(path)
parameters = []
if not path_item.parameters:
return []
for parameter in path_item.parameters:
path_item = self._get_path_strict(path)
parameters = []
if not path_item.parameters:
return []
for parameter in path_item.parameters:
if isinstance(parameter, Reference):
parameter = self._get_root_referenced_parameter(parameter)
parameters.append(parameter)
return parameters
def get_operation(self, path: str, method: str) -> Operation:
"""Get the operation object for a given path and HTTP method."""
from openapi_pydantic import Operation
path_item = self._get_path_strict(path)
operation_obj = getattr(path_item, method, None)
if not isinstance(operation_obj, Operation):
raise ValueError(f"No {method} method found for {path}")
return operation_obj
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
"""Get the components for a given operation."""
from openapi_pydantic import Reference
parameters = []
if operation.parameters:
for parameter in operation.parameters:
if isinstance(parameter, Reference):
parameter = self._get_root_referenced_parameter(parameter)
parameters.append(parameter)
return parameters
return parameters
def get_operation(self, path: str, method: str) -> Operation:
"""Get the operation object for a given path and HTTP method."""
from openapi_schema_pydantic import Operation
def get_request_body_for_operation(
self, operation: Operation
) -> Optional[RequestBody]:
"""Get the request body for a given operation."""
from openapi_pydantic import Reference
path_item = self._get_path_strict(path)
operation_obj = getattr(path_item, method, None)
if not isinstance(operation_obj, Operation):
raise ValueError(f"No {method} method found for {path}")
return operation_obj
request_body = operation.requestBody
if isinstance(request_body, Reference):
request_body = self._get_root_referenced_request_body(request_body)
return request_body
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
"""Get the components for a given operation."""
from openapi_schema_pydantic import Reference
parameters = []
if operation.parameters:
for parameter in operation.parameters:
if isinstance(parameter, Reference):
parameter = self._get_root_referenced_parameter(parameter)
parameters.append(parameter)
return parameters
def get_request_body_for_operation(
self, operation: Operation
) -> Optional[RequestBody]:
"""Get the request body for a given operation."""
from openapi_schema_pydantic import Reference
request_body = operation.requestBody
if isinstance(request_body, Reference):
request_body = self._get_root_referenced_request_body(request_body)
return request_body
@staticmethod
def get_cleaned_operation_id(
operation: Operation, path: str, method: str
) -> str:
"""Get a cleaned operation id from an operation id."""
operation_id = operation.operationId
if operation_id is None:
# Replace all punctuation of any kind with underscore
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
else:
class OpenAPISpec: # type: ignore[no-redef]
"""Shim for pydantic version >=2"""
def __init__(self) -> None:
raise NotImplementedError("Only supported for pydantic version 1")
@staticmethod
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
"""Get a cleaned operation id from an operation id."""
operation_id = operation.operationId
if operation_id is None:
# Replace all punctuation of any kind with underscore
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")

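A minimal loading sketch after the dependency swap. It assumes openapi-pydantic is installed and that this class lives at langchain.utilities.openapi; the Petstore URL is illustrative:

from langchain.utilities.openapi import OpenAPISpec

spec = OpenAPISpec.from_url("https://petstore3.swagger.io/api/v3/openapi.json")
print(spec.base_url)
for path in spec.paths or {}:
    # get_methods_for_path returns only the HTTP verbs defined on the path item
    print(path, spec.get_methods_for_path(path))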
View File

@@ -156,12 +156,22 @@ class ElasticVectorSearch(VectorStore):
self.index_name = index_name
_ssl_verify = ssl_verify or {}
try:
self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify)
self.client = elasticsearch.Elasticsearch(
elasticsearch_url,
**_ssl_verify,
headers={"user-agent": self.get_user_agent()},
)
except ValueError as e:
raise ValueError(
f"Your elasticsearch client string is mis-formatted. Got error: {e} "
)
@staticmethod
def get_user_agent() -> str:
from langchain import __version__
return f"langchain-py-dvs/{__version__}"
@property
def embeddings(self) -> Embeddings:
return self.embedding

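A short sketch of the effect; the URL and index name are placeholders, and FakeEmbeddings stands in for a real embedding model:

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch

store = ElasticVectorSearch(
    elasticsearch_url="http://localhost:9200",  # placeholder
    index_name="test-index",
    embedding=FakeEmbeddings(size=32),
)
# Requests from this client now carry a langchain user-agent for usage tracking:
print(store.get_user_agent())  # e.g. "langchain-py-dvs/0.0.318"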
View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import operator
import os
import pickle
@@ -15,6 +16,7 @@ from typing import (
Optional,
Sized,
Tuple,
Union,
)
import numpy as np
@@ -26,6 +28,8 @@ from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
logger = logging.getLogger(__name__)
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
"""
@@ -82,7 +86,7 @@ class FAISS(VectorStore):
def __init__(
self,
embedding_function: Callable,
embedding_function: Union[Callable, Embeddings],
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str],
@@ -91,6 +95,11 @@ class FAISS(VectorStore):
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
):
"""Initialize with necessary components."""
if not isinstance(embedding_function, Embeddings):
logger.warning(
"`embedding_function` is expected to be an Embeddings object, support "
"for passing in a function will soon be removed."
)
self.embedding_function = embedding_function
self.index = index
self.docstore = docstore
@@ -108,6 +117,26 @@ class FAISS(VectorStore):
)
)
@property
def embeddings(self) -> Optional[Embeddings]:
return (
self.embedding_function
if isinstance(self.embedding_function, Embeddings)
else None
)
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
if isinstance(self.embedding_function, Embeddings):
return self.embedding_function.embed_documents(texts)
else:
return [self.embedding_function(text) for text in texts]
def _embed_query(self, text: str) -> List[float]:
if isinstance(self.embedding_function, Embeddings):
return self.embedding_function.embed_query(text)
else:
return self.embedding_function(text)
def __add(
self,
texts: Iterable[str],
@@ -163,7 +192,8 @@ class FAISS(VectorStore):
Returns:
List of ids from adding the texts into the vectorstore.
"""
embeddings = [self.embedding_function(text) for text in texts]
texts = list(texts)
embeddings = self._embed_documents(texts)
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
def add_embeddings(
@@ -272,7 +302,7 @@ class FAISS(VectorStore):
List of documents most similar to the query text with
L2 distance in float. Lower score represents more similarity.
"""
embedding = self.embedding_function(query)
embedding = self._embed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding,
k,
@@ -465,7 +495,7 @@ class FAISS(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
embedding = self.embedding_function(query)
embedding = self._embed_query(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding,
k=k,
@@ -561,7 +591,7 @@ class FAISS(VectorStore):
# Default to L2, currently other metric types not initialized.
index = faiss.IndexFlatL2(len(embeddings[0]))
vecstore = cls(
embedding.embed_query,
embedding,
index,
InMemoryDocstore(),
{},
@@ -696,9 +726,7 @@ class FAISS(VectorStore):
# load docstore and index_to_docstore_id
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
docstore, index_to_docstore_id = pickle.load(f)
return cls(
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
)
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
def serialize_to_bytes(self) -> bytes:
"""Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
@@ -713,9 +741,7 @@ class FAISS(VectorStore):
) -> FAISS:
"""Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
index, docstore, index_to_docstore_id = pickle.loads(serialized)
return cls(
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
)
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""

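A sketch of the new construction path, in which the Embeddings object itself is stored instead of a bare embed_query callable; FakeEmbeddings is a stand-in and faiss-cpu must be installed:

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import FAISS

emb = FakeEmbeddings(size=32)
vs = FAISS.from_texts(["hello world", "goodbye world"], emb)

# The embeddings property now returns the object rather than None:
assert vs.embeddings is emb
docs = vs.similarity_search("hello", k=1)

# Passing a bare callable such as emb.embed_query still works, but per the
# __init__ change above it logs a warning that this path will be removed.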
View File

@@ -209,6 +209,8 @@ class Weaviate(VectorStore):
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("tenant"):
query_obj = query_obj.with_tenant(kwargs.get("tenant"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_near_text(content).with_limit(k).do()
@@ -228,6 +230,8 @@ class Weaviate(VectorStore):
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("tenant"):
query_obj = query_obj.with_tenant(kwargs.get("tenant"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
result = query_obj.with_near_vector(vector).with_limit(k).do()
@@ -304,6 +308,8 @@ class Weaviate(VectorStore):
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("tenant"):
query_obj = query_obj.with_tenant(kwargs.get("tenant"))
results = (
query_obj.with_additional("vector")
.with_near_vector(vector)
@@ -343,6 +349,8 @@ class Weaviate(VectorStore):
query_obj = self._client.query.get(self._index_name, self._query_attrs)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("tenant"):
query_obj = query_obj.with_tenant(kwargs.get("tenant"))
embedded_query = self._embedding.embed_query(query)
if not self._by_text:

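A hedged sketch of the new tenant pass-through; the client setup and tenant name are illustrative, and the underlying Weaviate class must have multi-tenancy enabled:

import weaviate

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import Weaviate

client = weaviate.Client("http://localhost:8080")  # placeholder
store = Weaviate(
    client=client,
    index_name="Document",
    text_key="text",
    embedding=FakeEmbeddings(size=32),
)
# The tenant kwarg is forwarded to query_obj.with_tenant(...) in each search path:
docs = store.similarity_search("multi tenancy", k=4, tenant="tenantA")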
File diff suppressed because it is too large

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.317"
version = "0.0.318"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -20,7 +20,7 @@ PyYAML = ">=5.3"
numpy = "^1"
azure-core = {version = "^1.26.4", optional=true}
tqdm = {version = ">=4.48.0", optional = true}
openapi-schema-pydantic = {version = "^1.2", optional = true}
openapi-pydantic = {version = "^0.3.2", optional = true}
faiss-cpu = {version = "^1", optional = true}
wikipedia = {version = "^1", optional = true}
elasticsearch = {version = "^8", optional = true}
@@ -359,7 +359,7 @@ extended_testing = [
"xata",
"xmltodict",
"faiss-cpu",
"openapi-schema-pydantic",
"openapi-pydantic",
"markdownify",
"arxiv",
"dashvector",

View File

@@ -0,0 +1,82 @@
"""Test AliCloud Pai Eas Chat Model."""
import os
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_pai_eas_call() -> None:
chat = PaiEasChatEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
)
response = chat(messages=[HumanMessage(content="Say foo:")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_multiple_history() -> None:
"""Tests multiple history works."""
chat = PaiEasChatEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
)
response = chat(
messages=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="How are you doing?"),
]
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_stream() -> None:
"""Test that stream works."""
chat = PaiEasChatEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
streaming=True,
)
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
response = chat(
messages=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="Who are you?"),
],
stream=True,
callbacks=callback_manager,
)
assert callback_handler.llm_streams > 0
assert isinstance(response.content, str)
def test_multiple_messages() -> None:
"""Tests multiple messages works."""
chat = PaiEasChatEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
)
message = HumanMessage(content="Hi, how are you.")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content

View File

@@ -0,0 +1,59 @@
"""Test PaiEasEndpoint API wrapper."""
import os
from typing import Generator
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
def test_pai_eas_v1_call() -> None:
"""Test valid call to PAI-EAS Service."""
llm = PaiEasEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
version="1.0",
)
output = llm("Say foo:")
assert isinstance(output, str)
def test_pai_eas_v2_call() -> None:
llm = PaiEasEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
version="2.0",
)
output = llm("Say foo:")
assert isinstance(output, str)
def test_pai_eas_v1_streaming() -> None:
"""Test streaming call to PAI-EAS Service."""
llm = PaiEasEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
version="1.0",
)
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
stream_results_string = ""
assert isinstance(generator, Generator)
for chunk in generator:
assert isinstance(chunk, str)
stream_results_string += chunk
assert len(stream_results_string.strip()) > 1
def test_pai_eas_v2_streaming() -> None:
llm = PaiEasEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
version="2.0",
)
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
stream_results_string = ""
assert isinstance(generator, Generator)
for chunk in generator:
assert isinstance(chunk, str)
stream_results_string += chunk
assert len(stream_results_string.strip()) > 1

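Outside the test suite, the two endpoints can be wired up directly; a minimal sketch in which the env-var names follow the tests above and the values are placeholders:

import os

from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
from langchain.schema import HumanMessage

llm = PaiEasEndpoint(
    eas_service_url=os.environ["EAS_SERVICE_URL"],
    eas_service_token=os.environ["EAS_SERVICE_TOKEN"],
)
print(llm("Say foo:"))

chat = PaiEasChatEndpoint(
    eas_service_url=os.environ["EAS_SERVICE_URL"],
    eas_service_token=os.environ["EAS_SERVICE_TOKEN"],
)
print(chat([HumanMessage(content="Hello!")]).content)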
View File

@@ -15,10 +15,8 @@ import os
import pytest
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import (
GoogleCloudEnterpriseSearchRetriever,
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)

View File

@@ -531,7 +531,7 @@ class TestElasticsearch:
},
}
},
settings={"index": {"default_pipeline": "pipeline"}},
settings={"index": {"default_pipeline": "test_pipeline"}},
)
# adding documents to the index

View File

@@ -49,14 +49,14 @@ def test_visit_comparison_range_lte() -> None:
def test_visit_comparison_range_match() -> None:
comp = Comparison(comparator=Comparator.CONTAIN, attribute="foo", value="1")
expected = {"match": {"metadata.foo": "1"}}
expected = {"match": {"metadata.foo": {"query": "1"}}}
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
assert expected == actual
def test_visit_comparison_range_like() -> None:
comp = Comparison(comparator=Comparator.LIKE, attribute="foo", value="bar")
expected = {"fuzzy": {"metadata.foo": {"value": "bar", "fuzziness": "AUTO"}}}
expected = {"match": {"metadata.foo": {"query": "bar", "fuzziness": "AUTO"}}}
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
assert expected == actual
@@ -200,9 +200,9 @@ def test_visit_structured_query_complex() -> None:
"should": [
{"range": {"metadata.bar": {"lt": 1}}},
{
"fuzzy": {
"match": {
"metadata.bar": {
"value": "10",
"query": "10",
"fuzziness": "AUTO",
}
}

View File

@@ -2163,19 +2163,19 @@
dict({
'anyOf': list([
dict({
'$ref': '#/definitions/HumanMessage',
'$ref': '#/definitions/AIMessage',
}),
dict({
'$ref': '#/definitions/AIMessage',
'$ref': '#/definitions/HumanMessage',
}),
dict({
'$ref': '#/definitions/ChatMessage',
}),
dict({
'$ref': '#/definitions/FunctionMessage',
'$ref': '#/definitions/SystemMessage',
}),
dict({
'$ref': '#/definitions/SystemMessage',
'$ref': '#/definitions/FunctionMessage',
}),
]),
'definitions': dict({

View File

@@ -7,7 +7,7 @@ from typing import Iterable, List, Tuple
import pytest
# Keep at top of file to ensure that pydantic test can be skipped before
# pydantic v1 related imports are attempted by openapi_schema_pydantic.
# pydantic v1 related imports are attempted by openapi_pydantic.
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION
if _PYDANTIC_MAJOR_VERSION != 1:
@@ -78,7 +78,7 @@ def http_paths_and_methods() -> List[Tuple[str, OpenAPISpec, str, str]]:
return http_paths_and_methods
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_parse_api_operations() -> None:
"""Test the APIOperation class."""
for spec_name, spec, path, method in http_paths_and_methods():
@@ -88,21 +88,21 @@ def test_parse_api_operations() -> None:
raise AssertionError(f"Error processing {spec_name}: {e} ") from e
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
@pytest.fixture
def raw_spec() -> OpenAPISpec:
"""Return a raw OpenAPI spec."""
from openapi_schema_pydantic import Info
from openapi_pydantic import Info
return OpenAPISpec(
info=Info(title="Test API", version="1.0.0"),
)
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> None:
"""Test instantiating APIRequestBody from RequestBody with a reference."""
from openapi_schema_pydantic import (
from openapi_pydantic import (
Components,
MediaType,
Reference,
@@ -140,10 +140,10 @@ def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> N
assert api_request_body.media_type == "application/json"
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -> None:
"""Test instantiating APIRequestBody from RequestBody with a schema."""
from openapi_schema_pydantic import (
from openapi_pydantic import (
MediaType,
RequestBody,
Schema,
@@ -171,9 +171,9 @@ def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -
assert api_request_body.media_type == "application/json"
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_property_from_schema(raw_spec: OpenAPISpec) -> None:
from openapi_schema_pydantic import (
from openapi_pydantic import (
Components,
Reference,
Schema,