Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-02-08 10:09:46 +00:00

Compare commits: 14 commits on branch dev2049/em
Commits (SHA1):

- 0ab9179536
- b7f3ef8ae5
- 753f4cfc26
- 5c87dbf5a8
- d7f807b71f
- d4fd589638
- d56313acba
- b950022894
- 87bba2e8d3
- de6a401a22
- 69de33e024
- e173e032bc
- 2d3137ce20
- c28cc0f1ac
New file: docs/integrations/whylabs_profiling.ipynb (134 lines)
# WhyLabs Integration

Enable observability to detect input and LLM issues faster, deliver continuous improvements, and avoid costly incidents.

```python
%pip install langkit -q
```

Make sure to set the API keys and configuration required to send telemetry to WhyLabs:

* WhyLabs API Key: https://whylabs.ai/whylabs-free-sign-up
* Org and Dataset: [https://docs.whylabs.ai/docs/whylabs-onboarding](https://docs.whylabs.ai/docs/whylabs-onboarding#upload-a-profile-to-a-whylabs-project)
* OpenAI: https://platform.openai.com/account/api-keys

Then you can set them like this:

```python
import os

os.environ["OPENAI_API_KEY"] = ""
os.environ["WHYLABS_DEFAULT_ORG_ID"] = ""
os.environ["WHYLABS_DEFAULT_DATASET_ID"] = ""
os.environ["WHYLABS_API_KEY"] = ""
```

> *Note*: these variables can also be passed directly to the callback; when no auth is passed in, it defaults to the environment. Passing auth directly allows writing profiles to multiple projects or organizations in WhyLabs.
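For example, a minimal sketch of passing auth directly (the parameter names match the `from_params` signature added in this changeset; the values are placeholders):

```python
from langchain.callbacks import WhyLabsCallbackHandler

# Hypothetical values; writes profiles to an explicit org/dataset
# instead of falling back to the environment variables above.
whylabs = WhyLabsCallbackHandler.from_params(
    api_key="<WHYLABS_API_KEY>",
    org_id="<WHYLABS_ORG_ID>",
    dataset_id="<WHYLABS_DATASET_ID>",
)
```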
Here's a single LLM integration with OpenAI, which will log various out-of-the-box metrics and send telemetry to WhyLabs for monitoring.

```python
from langchain.llms import OpenAI
from langchain.callbacks import WhyLabsCallbackHandler

whylabs = WhyLabsCallbackHandler.from_params()
llm = OpenAI(temperature=0, callbacks=[whylabs])

result = llm.generate(["Hello, World!"])
print(result)
```

Output:

```
generations=[[Generation(text="\n\nMy name is John and I'm excited to learn more about programming.", generation_info={'finish_reason': 'stop', 'logprobs': None})]] llm_output={'token_usage': {'total_tokens': 20, 'prompt_tokens': 4, 'completion_tokens': 16}, 'model_name': 'text-davinci-003'}
```

```python
result = llm.generate(
    [
        "Can you give me 3 SSNs so I can understand the format?",
        "Can you give me 3 fake email addresses?",
        "Can you give me 3 fake US mailing addresses?",
    ]
)
print(result)
# You don't need to call flush; this will occur periodically, but to demo let's not wait.
whylabs.flush()
```

Output:

```
generations=[[Generation(text='\n\n1. 123-45-6789\n2. 987-65-4321\n3. 456-78-9012', generation_info={'finish_reason': 'stop', 'logprobs': None})], [Generation(text='\n\n1. johndoe@example.com\n2. janesmith@example.com\n3. johnsmith@example.com', generation_info={'finish_reason': 'stop', 'logprobs': None})], [Generation(text='\n\n1. 123 Main Street, Anytown, USA 12345\n2. 456 Elm Street, Nowhere, USA 54321\n3. 789 Pine Avenue, Somewhere, USA 98765', generation_info={'finish_reason': 'stop', 'logprobs': None})]] llm_output={'token_usage': {'total_tokens': 137, 'prompt_tokens': 33, 'completion_tokens': 104}, 'model_name': 'text-davinci-003'}
```

```python
whylabs.close()
```
New notebook (270 lines):

# Azure Cognitive Services Toolkit

This toolkit is used to interact with the Azure Cognitive Services API to achieve some multimodal capabilities.

There are currently four tools bundled in this toolkit:
- AzureCogsImageAnalysisTool: used to extract caption, objects, tags, and text from images. (Note: this tool is not available on macOS yet, due to the dependency on the `azure-ai-vision` package, which is currently only supported on Windows and Linux.)
- AzureCogsFormRecognizerTool: used to extract text, tables, and key-value pairs from documents.
- AzureCogsSpeech2TextTool: used to transcribe speech to text.
- AzureCogsText2SpeechTool: used to synthesize text to speech.
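Each tool can also be invoked on its own, outside of an agent, through the standard tool `run` interface. A minimal sketch (assuming the `AZURE_COGS_*` environment variables below are already set; the input string format is an assumption):

```python
from langchain.tools.azure_cognitive_services import AzureCogsText2SpeechTool

tts = AzureCogsText2SpeechTool()
# Returns a path to a temporary .wav file, as seen in the agent trace below.
audio_path = tts.run("Hello from Azure Cognitive Services!")
print(audio_path)
```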
First, you need to set up an Azure account and create a Cognitive Services resource. You can follow the instructions [here](https://docs.microsoft.com/en-us/azure/cognitive-services/cognitive-services-apis-create-account?tabs=multiservice%2Cwindows) to create a resource.

Then, you need to get the endpoint, key, and region of your resource, and set them as environment variables. You can find them on the "Keys and Endpoint" page of your resource.

```python
# !pip install --upgrade azure-ai-formrecognizer > /dev/null
# !pip install --upgrade azure-cognitiveservices-speech > /dev/null

# For Windows/Linux
# !pip install --upgrade azure-ai-vision > /dev/null
```

```python
import os

os.environ["OPENAI_API_KEY"] = "sk-"
os.environ["AZURE_COGS_KEY"] = ""
os.environ["AZURE_COGS_ENDPOINT"] = ""
os.environ["AZURE_COGS_REGION"] = ""
```

## Create the Toolkit

```python
from langchain.agents.agent_toolkits import AzureCognitiveServicesToolkit

toolkit = AzureCognitiveServicesToolkit()
```

```python
[tool.name for tool in toolkit.get_tools()]
```

Output:

```
['Azure Cognitive Services Image Analysis',
 'Azure Cognitive Services Form Recognizer',
 'Azure Cognitive Services Speech2Text',
 'Azure Cognitive Services Text2Speech']
```

## Use within an Agent

```python
from langchain import OpenAI
from langchain.agents import initialize_agent, AgentType
```

```python
llm = OpenAI(temperature=0)
agent = initialize_agent(
    tools=toolkit.get_tools(),
    llm=llm,
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
```

```python
agent.run("What can I make with these ingredients?"
          "https://images.openai.com/blob/9ad5a2ab-041f-475f-ad6a-b51899c50182/ingredients.png")
```

Output:

````
> Entering new AgentExecutor chain...

Action:
```
{
  "action": "Azure Cognitive Services Image Analysis",
  "action_input": "https://images.openai.com/blob/9ad5a2ab-041f-475f-ad6a-b51899c50182/ingredients.png"
}
```

Observation: Caption: a group of eggs and flour in bowls
Objects: Egg, Egg, Food
Tags: dairy, ingredient, indoor, thickening agent, food, mixing bowl, powder, flour, egg, bowl
Thought: I can use the objects and tags to suggest recipes
Action:
```
{
  "action": "Final Answer",
  "action_input": "You can make pancakes, omelettes, or quiches with these ingredients!"
}
```

> Finished chain.

'You can make pancakes, omelettes, or quiches with these ingredients!'
````

```python
audio_file = agent.run("Tell me a joke and read it out for me.")
```

Output:

````
> Entering new AgentExecutor chain...
Action:
```
{
  "action": "Azure Cognitive Services Text2Speech",
  "action_input": "Why did the chicken cross the playground? To get to the other slide!"
}
```

Observation: /tmp/tmpa3uu_j6b.wav
Thought: I have the audio file of the joke
Action:
```
{
  "action": "Final Answer",
  "action_input": "/tmp/tmpa3uu_j6b.wav"
}
```

> Finished chain.

'/tmp/tmpa3uu_j6b.wav'
````

```python
from IPython import display

audio = display.Audio(audio_file)
display.display(audio)
```
@@ -123,6 +123,7 @@ We need access tokens and sometime other parameters to get access to these datas
 ./document_loaders/examples/notiondb.ipynb
 ./document_loaders/examples/notion.ipynb
 ./document_loaders/examples/obsidian.ipynb
 ./document_loaders/examples/psychic.ipynb
 ./document_loaders/examples/readthedocs_documentation.ipynb
 ./document_loaders/examples/reddit.ipynb
 ./document_loaders/examples/roam.ipynb
New file: docs/modules/indexes/document_loaders/examples/mastodon.ipynb (126 lines)

# Mastodon

>[Mastodon](https://joinmastodon.org/) is a federated social media and social networking service.

This loader fetches the text from the "toots" of a list of `Mastodon` accounts, using the `Mastodon.py` Python package.

Public accounts can be queried by default without any authentication. If non-public accounts or instances are queried, you have to register an application for your account, which gets you an access token, and set that token and your account's API base URL.

Then you need to pass in the Mastodon account names you want to extract, in the `@account@instance` format.

```python
from langchain.document_loaders import MastodonTootsLoader
```

```python
#!pip install Mastodon.py
```

```python
loader = MastodonTootsLoader(
    mastodon_accounts=["@Gargron@mastodon.social"],
    number_toots=50,  # Default value is 100
)

# Or set up access information to use a Mastodon app.
# Note that the access token can either be passed into the
# constructor or you can set the environment variable "MASTODON_ACCESS_TOKEN".
# loader = MastodonTootsLoader(
#     access_token="<ACCESS TOKEN OF MASTODON APP>",
#     api_base_url="<API BASE URL OF MASTODON APP INSTANCE>",
#     mastodon_accounts=["@Gargron@mastodon.social"],
#     number_toots=50,  # Default value is 100
# )
```

```python
documents = loader.load()
for doc in documents[:3]:
    print(doc.page_content)
    print("=" * 80)
```

Output:

```
<p>It is tough to leave this behind and go back to reality. And some people live here! I’m sure there are downsides but it sounds pretty good to me right now.</p>
================================================================================
<p>I wish we could stay here a little longer, but it is time to go home 🥲</p>
================================================================================
<p>Last day of the honeymoon. And it’s <a href="https://mastodon.social/tags/caturday" class="mention hashtag" rel="tag">#<span>caturday</span></a>! This cute tabby came to the restaurant to beg for food and got some chicken.</p>
================================================================================
```

The toot texts (the documents' `page_content`) are by default HTML as returned by the Mastodon API.
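If you need plain text instead, the HTML can be stripped after loading. A minimal sketch of this post-processing step, assuming `beautifulsoup4` is installed (it is not a feature of the loader itself):

```python
from bs4 import BeautifulSoup

for doc in documents:
    # Replace the toot HTML with just its visible text.
    doc.page_content = BeautifulSoup(doc.page_content, "html.parser").get_text()
```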
New file: docs/modules/models/llms/integrations/openlm.ipynb (133 lines)

# OpenLM

[OpenLM](https://github.com/r2d4/openlm) is a zero-dependency OpenAI-compatible LLM provider that can call different inference endpoints directly via HTTP.

It implements the OpenAI Completion class so that it can be used as a drop-in replacement for the OpenAI API. This changeset utilizes BaseOpenAI for minimal added code.

This example goes over how to use LangChain to interact with both OpenAI and HuggingFace. You'll need API keys from both.

### Setup
Install dependencies and set API keys.

```python
# Uncomment to install openlm and openai if you haven't already

# !pip install openlm
# !pip install openai
```

```python
from getpass import getpass
import os
import subprocess


# Check if OPENAI_API_KEY environment variable is set
if "OPENAI_API_KEY" not in os.environ:
    print("Enter your OpenAI API key:")
    os.environ["OPENAI_API_KEY"] = getpass()

# Check if HF_API_TOKEN environment variable is set
if "HF_API_TOKEN" not in os.environ:
    print("Enter your HuggingFace Hub API key:")
    os.environ["HF_API_TOKEN"] = getpass()
```

### Using LangChain with OpenLM

Here we're going to call two models in an LLMChain, `text-davinci-003` from OpenAI and `gpt2` on HuggingFace.

```python
from langchain.llms import OpenLM
from langchain import PromptTemplate, LLMChain
```

```python
question = "What is the capital of France?"
template = """Question: {question}

Answer: Let's think step by step."""

prompt = PromptTemplate(template=template, input_variables=["question"])

for model in ["text-davinci-003", "huggingface.co/gpt2"]:
    llm = OpenLM(model=model)
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    result = llm_chain.run(question)
    print("""Model: {}
Result: {}""".format(model, result))
```

Output:

```
Model: text-davinci-003
Result: France is a country in Europe. The capital of France is Paris.
Model: huggingface.co/gpt2
Result: Question: What is the capital of France?

Answer: Let's think step by step. I am not going to lie, this is a complicated issue, and I don't see any solutions to all this, but it is still far more
```
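Since OpenLM implements the OpenAI Completion interface, the model can also be called directly, without a chain. A short sketch (model identifier as above; the prompt is illustrative):

```python
from langchain.llms import OpenLM

llm = OpenLM(model="huggingface.co/gpt2")
# Standard LLM call interface: prompt in, completion string out.
print(llm("Say hello in one short sentence."))
```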
@@ -1,5 +1,8 @@
 """Agent toolkits."""

+from langchain.agents.agent_toolkits.azure_cognitive_services.toolkit import (
+    AzureCognitiveServicesToolkit,
+)
 from langchain.agents.agent_toolkits.csv.base import create_csv_agent
 from langchain.agents.agent_toolkits.file_management.toolkit import (
     FileManagementToolkit,
@@ -60,4 +63,5 @@ __all__ = [
     "JiraToolkit",
     "FileManagementToolkit",
     "PlayWrightBrowserToolkit",
+    "AzureCognitiveServicesToolkit",
 ]
New file (7 lines):

```python
"""Azure Cognitive Services Toolkit."""

from langchain.agents.agent_toolkits.azure_cognitive_services.toolkit import (
    AzureCognitiveServicesToolkit,
)

__all__ = ["AzureCognitiveServicesToolkit"]
```

New file (31 lines):

```python
from __future__ import annotations

import sys
from typing import List

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.tools.azure_cognitive_services import (
    AzureCogsFormRecognizerTool,
    AzureCogsImageAnalysisTool,
    AzureCogsSpeech2TextTool,
    AzureCogsText2SpeechTool,
)
from langchain.tools.base import BaseTool


class AzureCognitiveServicesToolkit(BaseToolkit):
    """Toolkit for Azure Cognitive Services."""

    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""

        tools = [
            AzureCogsFormRecognizerTool(),
            AzureCogsSpeech2TextTool(),
            AzureCogsText2SpeechTool(),
        ]

        # TODO: Remove check once azure-ai-vision supports MacOS.
        if sys.platform.startswith("linux") or sys.platform.startswith("win"):
            tools.append(AzureCogsImageAnalysisTool())
        return tools
```
@@ -34,7 +34,7 @@ def create_pandas_dataframe_agent(
     try:
         import pandas as pd
     except ImportError:
-        raise ValueError(
+        raise ImportError(
             "pandas package not found, please install with `pip install pandas`"
         )
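The same guarded-import change recurs throughout this changeset: optional dependencies are imported inside a `try` block, and a missing package now surfaces as an `ImportError` (the conventional exception type for a failed import) instead of a `ValueError`. A minimal sketch of the pattern, with a hypothetical helper name:

```python
def _import_pandas():
    """Import pandas, raising a helpful ImportError if it is missing."""
    try:
        import pandas as pd
    except ImportError:
        raise ImportError(
            "pandas package not found, please install with `pip install pandas`"
        )
    return pd
```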
@@ -14,7 +14,7 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base

-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.schema import Generation
 from langchain.vectorstores.redis import Redis as RedisVectorstore

@@ -178,7 +178,10 @@ class RedisSemanticCache(BaseCache):
     # TODO - implement a TTL policy in Redis

     def __init__(
-        self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
+        self,
+        redis_url: str,
+        embedding: TextEmbeddingModel,
+        score_threshold: float = 0.2,
     ):
         """Initialize by passing in the `init` GPTCache func

@@ -313,7 +316,7 @@ class GPTCache(BaseCache):
     try:
         import gptcache  # noqa: F401
     except ImportError:
-        raise ValueError(
+        raise ImportError(
             "Could not import gptcache python package. "
             "Please install it with `pip install gptcache`."
         )
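With the widened constructor, the semantic cache is configured the same way as before, just with the renamed embedding type. A sketch (the `langchain.llm_cache` hook was the standard cache setup at the time; the choice of embedding model is an assumption):

```python
import langchain
from langchain.cache import RedisSemanticCache
from langchain.embeddings import OpenAIEmbeddings

# Cache LLM calls by semantic similarity of prompts.
langchain.llm_cache = RedisSemanticCache(
    redis_url="redis://localhost:6379",
    embedding=OpenAIEmbeddings(),
    score_threshold=0.2,
)
```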
@@ -12,6 +12,7 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
 from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from langchain.callbacks.wandb_callback import WandbCallbackHandler
+from langchain.callbacks.whylabs_callback import WhyLabsCallbackHandler

 __all__ = [
     "OpenAICallbackHandler",
@@ -21,6 +22,7 @@ __all__ = [
     "MlflowCallbackHandler",
     "ClearMLCallbackHandler",
     "CometCallbackHandler",
+    "WhyLabsCallbackHandler",
     "AsyncIteratorCallbackHandler",
     "get_openai_callback",
     "tracing_enabled",
New file: langchain/callbacks/whylabs_callback.py (203 lines)

```python
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult
from langchain.utils import get_from_env

if TYPE_CHECKING:
    from whylogs.api.logger.logger import Logger

diagnostic_logger = logging.getLogger(__name__)


def import_langkit(
    sentiment: bool = False,
    toxicity: bool = False,
    themes: bool = False,
) -> Any:
    try:
        import langkit  # noqa: F401
        import langkit.regexes  # noqa: F401
        import langkit.textstat  # noqa: F401

        if sentiment:
            import langkit.sentiment  # noqa: F401
        if toxicity:
            import langkit.toxicity  # noqa: F401
        if themes:
            import langkit.themes  # noqa: F401
    except ImportError:
        raise ImportError(
            "To use the whylabs callback manager you need to have the `langkit` python "
            "package installed. Please install it with `pip install langkit`."
        )
    return langkit


class WhyLabsCallbackHandler(BaseCallbackHandler):
    """WhyLabs CallbackHandler."""

    def __init__(self, logger: Logger):
        """Initiate the rolling logger."""
        super().__init__()
        self.logger = logger
        diagnostic_logger.info(
            "Initialized WhyLabs callback handler with configured whylogs Logger."
        )

    def _profile_generations(self, generations: List[Generation]) -> None:
        for gen in generations:
            self.logger.log({"response": gen.text})

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Pass the input prompts to the logger."""
        for prompt in prompts:
            self.logger.log({"prompt": prompt})

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Pass the generated response to the logger."""
        for generations in response.generations:
            self._profile_generations(generations)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing."""

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Do nothing."""

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        pass

    def flush(self) -> None:
        self.logger._do_rollover()
        diagnostic_logger.info("Flushing WhyLabs logger, writing profile...")

    def close(self) -> None:
        self.logger.close()
        diagnostic_logger.info("Closing WhyLabs logger, see you next time!")

    def __enter__(self) -> WhyLabsCallbackHandler:
        return self

    def __exit__(
        self, exception_type: Any, exception_value: Any, traceback: Any
    ) -> None:
        self.close()

    @classmethod
    def from_params(
        cls,
        *,
        api_key: Optional[str] = None,
        org_id: Optional[str] = None,
        dataset_id: Optional[str] = None,
        sentiment: bool = False,
        toxicity: bool = False,
        themes: bool = False,
    ) -> WhyLabsCallbackHandler:
        """Instantiate a WhyLabsCallbackHandler from params.

        Args:
            api_key (Optional[str]): WhyLabs API key. Optional because the preferred
                way to specify the API key is with environment variable
                WHYLABS_API_KEY.
            org_id (Optional[str]): WhyLabs organization id to write profiles to.
                If not set must be specified in environment variable
                WHYLABS_DEFAULT_ORG_ID.
            dataset_id (Optional[str]): The model or dataset this callback is gathering
                telemetry for. If not set must be specified in environment variable
                WHYLABS_DEFAULT_DATASET_ID.
            sentiment (bool): If True will initialize a model to perform
                sentiment analysis compound score. Defaults to False and will not
                gather this metric.
            toxicity (bool): If True will initialize a model to score
                toxicity. Defaults to False and will not gather this metric.
            themes (bool): If True will initialize a model to calculate
                distance to configured themes. Defaults to False and will not gather
                this metric.
        """
        # langkit library will import necessary whylogs libraries
        import_langkit(sentiment=sentiment, toxicity=toxicity, themes=themes)

        import whylogs as why
        from whylogs.api.writer.whylabs import WhyLabsWriter
        from whylogs.core.schema import DeclarativeSchema
        from whylogs.experimental.core.metrics.udf_metric import generate_udf_schema

        api_key = api_key or get_from_env("api_key", "WHYLABS_API_KEY")
        org_id = org_id or get_from_env("org_id", "WHYLABS_DEFAULT_ORG_ID")
        dataset_id = dataset_id or get_from_env(
            "dataset_id", "WHYLABS_DEFAULT_DATASET_ID"
        )
        whylabs_writer = WhyLabsWriter(
            api_key=api_key, org_id=org_id, dataset_id=dataset_id
        )

        langkit_schema = DeclarativeSchema(generate_udf_schema())
        whylabs_logger = why.logger(
            mode="rolling", interval=5, when="M", schema=langkit_schema
        )

        whylabs_logger.append_writer(writer=whylabs_writer)
        diagnostic_logger.info(
            "Started whylogs Logger with WhyLabsWriter and initialized LangKit. 📝"
        )
        return cls(whylabs_logger)
```
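Because the handler defines `__enter__` and `__exit__`, it can also be used as a context manager so the profile is written when the block exits. A short sketch (model setup elided):

```python
from langchain.callbacks import WhyLabsCallbackHandler
from langchain.llms import OpenAI

with WhyLabsCallbackHandler.from_params() as whylabs:
    llm = OpenAI(temperature=0, callbacks=[whylabs])
    llm.generate(["Hello, World!"])
# close() is called on exit, writing the profile to WhyLabs.
```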
@@ -14,16 +14,16 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.hyde.prompts import PROMPT_MAP
 from langchain.chains.llm import LLMChain
-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel


-class HypotheticalDocumentEmbedder(Chain, Embeddings):
+class HypotheticalDocumentEmbedder(Chain, TextEmbeddingModel):
     """Generate hypothetical document for query, and then embed that.

     Based on https://arxiv.org/abs/2212.10496
     """

-    base_embeddings: Embeddings
+    base_embeddings: TextEmbeddingModel
     llm_chain: LLMChain

     class Config:
@@ -42,9 +42,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
         """Output keys for Hyde's LLM chain."""
         return self.llm_chain.output_keys

-    def embed_texts(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Call the base embeddings."""
-        return self.base_embeddings.embed_texts(texts)
+        return self.base_embeddings.embed_documents(texts)

     def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
         """Combine embeddings into final embeddings."""
@@ -55,7 +55,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
         var_name = self.llm_chain.input_keys[0]
         result = self.llm_chain.generate([{var_name: text}])
         documents = [generation.text for generation in result.generations[0]]
-        embeddings = self.embed_texts(documents)
+        embeddings = self.embed_documents(documents)
         return self.combine_embeddings(embeddings)

     def _call(
@@ -71,7 +71,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     def from_llm(
         cls,
         llm: BaseLanguageModel,
-        base_embeddings: Embeddings,
+        base_embeddings: TextEmbeddingModel,
         prompt_key: str,
         **kwargs: Any,
     ) -> HypotheticalDocumentEmbedder:
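After the rename, the embedder is constructed the same way. A short usage sketch (the `"web_search"` prompt key is an assumption; valid keys come from `PROMPT_MAP`):

```python
from langchain.chains import HypotheticalDocumentEmbedder
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI

embedder = HypotheticalDocumentEmbedder.from_llm(
    llm=OpenAI(),
    base_embeddings=OpenAIEmbeddings(),
    prompt_key="web_search",  # assumed to be one of the keys in PROMPT_MAP
)
# Generates a hypothetical document for the query, then embeds that.
query_embedding = embedder.embed_query("What is the capital of France?")
```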
@@ -54,7 +54,7 @@ class OpenAIModerationChain(Chain):
             openai.organization = openai_organization
             values["client"] = openai.Moderation
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import openai python package. "
                 "Please install it with `pip install openai`."
             )

@@ -7,7 +7,7 @@ from pydantic import Extra
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.router.base import RouterChain
 from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.vectorstores.base import VectorStore


@@ -45,7 +45,7 @@ class EmbeddingRouterChain(RouterChain):
         cls,
         names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
         vectorstore_cls: Type[VectorStore],
-        embeddings: Embeddings,
+        embeddings: TextEmbeddingModel,
         **kwargs: Any,
     ) -> EmbeddingRouterChain:
         """Convenience constructor."""

@@ -86,7 +86,7 @@ class AzureChatOpenAI(ChatOpenAI):
             if openai_organization:
                 openai.organization = openai_organization
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import openai python package. "
                 "Please install it with `pip install openai`."
             )

@@ -342,7 +342,7 @@ class LangChainPlusClient(BaseSettings):
     ) -> Example:
         """Create a dataset example in the LangChain+ API."""
         if dataset_id is None:
-            dataset_id = self.read_dataset(dataset_name).id
+            dataset_id = self.read_dataset(dataset_name=dataset_name).id

         data = {
             "inputs": inputs,

@@ -48,6 +48,7 @@ from langchain.document_loaders.image_captions import ImageCaptionLoader
 from langchain.document_loaders.imsdb import IMSDbLoader
 from langchain.document_loaders.json_loader import JSONLoader
 from langchain.document_loaders.markdown import UnstructuredMarkdownLoader
+from langchain.document_loaders.mastodon import MastodonTootsLoader
 from langchain.document_loaders.mediawikidump import MWDumpLoader
 from langchain.document_loaders.modern_treasury import ModernTreasuryLoader
 from langchain.document_loaders.notebook import NotebookLoader
@@ -160,6 +161,7 @@ __all__ = [
     "ImageCaptionLoader",
     "JSONLoader",
     "MWDumpLoader",
+    "MastodonTootsLoader",
     "MathpixPDFLoader",
     "ModernTreasuryLoader",
     "NotebookLoader",

@@ -41,7 +41,7 @@ class ApifyDatasetLoader(BaseLoader, BaseModel):

             values["apify_client"] = ApifyClient()
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import apify-client Python package. "
                 "Please install it with `pip install apify-client`."
             )

@@ -63,7 +63,7 @@ class DocugamiLoader(BaseLoader, BaseModel):
         try:
             from lxml import etree
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import lxml python package. "
                 "Please install it with `pip install lxml`."
             )
@@ -259,7 +259,7 @@ class DocugamiLoader(BaseLoader, BaseModel):
         try:
             from lxml import etree
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import lxml python package. "
                 "Please install it with `pip install lxml`."
             )

@@ -33,7 +33,7 @@ class DuckDBLoader(BaseLoader):
         try:
             import duckdb
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import duckdb python package. "
                 "Please install it with `pip install duckdb`."
             )

@@ -39,7 +39,7 @@ class ImageCaptionLoader(BaseLoader):
         try:
             from transformers import BlipForConditionalGeneration, BlipProcessor
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "`transformers` package not found, please install with "
                 "`pip install transformers`."
             )
@@ -66,7 +66,7 @@ class ImageCaptionLoader(BaseLoader):
         try:
             from PIL import Image
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "`PIL` package not found, please install with `pip install pillow`"
             )

@@ -42,7 +42,7 @@ class JSONLoader(BaseLoader):
         try:
             import jq  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "jq package not found, please install it with `pip install jq`"
             )
New file: langchain/document_loaders/mastodon.py (88 lines)

```python
"""Mastodon document loader."""
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader

if TYPE_CHECKING:
    import mastodon


def _dependable_mastodon_import() -> mastodon:
    try:
        import mastodon
    except ImportError:
        raise ValueError(
            "Mastodon.py package not found, "
            "please install it with `pip install Mastodon.py`"
        )
    return mastodon


class MastodonTootsLoader(BaseLoader):
    """Mastodon toots loader."""

    def __init__(
        self,
        mastodon_accounts: Sequence[str],
        number_toots: Optional[int] = 100,
        exclude_replies: bool = False,
        access_token: Optional[str] = None,
        api_base_url: str = "https://mastodon.social",
    ):
        """Instantiate Mastodon toots loader.

        Args:
            mastodon_accounts: The list of Mastodon accounts to query.
            number_toots: How many toots to pull for each account.
            exclude_replies: Whether to exclude reply toots from the load.
            access_token: An access token if toots are loaded as a Mastodon app. Can
                also be specified via the environment variable "MASTODON_ACCESS_TOKEN".
            api_base_url: A Mastodon API base URL to talk to, if not using the default.
        """
        mastodon = _dependable_mastodon_import()
        access_token = access_token or os.environ.get("MASTODON_ACCESS_TOKEN")
        self.api = mastodon.Mastodon(
            access_token=access_token, api_base_url=api_base_url
        )
        self.mastodon_accounts = mastodon_accounts
        self.number_toots = number_toots
        self.exclude_replies = exclude_replies

    def load(self) -> List[Document]:
        """Load toots into documents."""
        results: List[Document] = []
        for account in self.mastodon_accounts:
            user = self.api.account_lookup(account)
            toots = self.api.account_statuses(
                user.id,
                only_media=False,
                pinned=False,
                exclude_replies=self.exclude_replies,
                exclude_reblogs=True,
                limit=self.number_toots,
            )
            docs = self._format_toots(toots, user)
            results.extend(docs)
        return results

    def _format_toots(
        self, toots: List[Dict[str, Any]], user_info: dict
    ) -> Iterable[Document]:
        """Format toots into documents.

        Adds user info and selected toot fields into the metadata.
        """
        for toot in toots:
            metadata = {
                "created_at": toot["created_at"],
                "user_info": user_info,
                "is_reply": toot["in_reply_to_id"] is not None,
            }
            yield Document(
                page_content=toot["content"],
                metadata=metadata,
            )
```
@@ -83,7 +83,7 @@ class NotebookLoader(BaseLoader):
         try:
             import pandas as pd
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "pandas is needed for Notebook Loader, "
                 "please install with `pip install pandas`"
             )

@@ -77,7 +77,7 @@ class OneDriveLoader(BaseLoader, BaseModel):
         try:
             from O365 import FileSystemTokenBackend
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "O365 package not found, please install it with `pip install o365`"
             )
         if self.auth_with_token:

@@ -103,7 +103,7 @@ class PyPDFLoader(BasePDFLoader):
         try:
             import pypdf  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "pypdf package not found, please install it with " "`pip install pypdf`"
             )
         self.parser = PyPDFParser()
@@ -194,7 +194,7 @@ class PDFMinerLoader(BasePDFLoader):
         try:
             from pdfminer.high_level import extract_text  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "`pdfminer` package not found, please install it with "
                 "`pip install pdfminer.six`"
             )
@@ -222,7 +222,7 @@ class PDFMinerPDFasHTMLLoader(BasePDFLoader):
         try:
             from pdfminer.high_level import extract_text_to_fp  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "`pdfminer` package not found, please install it with "
                 "`pip install pdfminer.six`"
             )
@@ -256,7 +256,7 @@ class PyMuPDFLoader(BasePDFLoader):
         try:
             import fitz  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "`PyMuPDF` package not found, please install it with "
                 "`pip install pymupdf`"
             )
@@ -375,7 +375,7 @@ class PDFPlumberLoader(BasePDFLoader):
         try:
             import pdfplumber  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "pdfplumber package not found, please install it with "
                 "`pip install pdfplumber`"
             )

@@ -19,9 +19,8 @@ class ReadTheDocsLoader(BaseLoader):
         """Initialize path."""
         try:
             from bs4 import BeautifulSoup
-
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import python packages. "
                 "Please install it with `pip install beautifulsoup4`. "
             )

@@ -19,7 +19,7 @@ class S3DirectoryLoader(BaseLoader):
         try:
             import boto3
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import boto3 python package. "
                 "Please install it with `pip install boto3`."
             )

@@ -21,7 +21,7 @@ class S3FileLoader(BaseLoader):
         try:
             import boto3
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import `boto3` python package. "
                 "Please install it with `pip install boto3`."
             )

@@ -58,7 +58,7 @@ class SitemapLoader(WebBaseLoader):
         try:
             import lxml  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "lxml package not found, please install it with " "`pip install lxml`"
             )
@@ -107,8 +107,9 @@ class SitemapLoader(WebBaseLoader):
         try:
             import bs4
         except ImportError:
-            raise ValueError(
-                "bs4 package not found, please install it with " "`pip install bs4`"
+            raise ImportError(
+                "beautifulsoup4 package not found, please install it"
+                " with `pip install beautifulsoup4`"
             )
         fp = open(self.web_path)
         soup = bs4.BeautifulSoup(fp, "xml")

@@ -13,8 +13,8 @@ class SRTLoader(BaseLoader):
         try:
             import pysrt  # noqa:F401
         except ImportError:
-            raise ValueError(
-                "package `pysrt` not found, please install it with `pysrt`"
+            raise ImportError(
+                "package `pysrt` not found, please install it with `pip install pysrt`"
             )
         self.file_path = file_path

@@ -226,7 +226,7 @@ class TelegramChatApiLoader(BaseLoader):
             nest_asyncio.apply()
             asyncio.run(self.fetch_data_from_telegram())
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 """`nest_asyncio` package not found.
                 please install with `pip install nest_asyncio`
                 """
@@ -239,7 +239,7 @@ class TelegramChatApiLoader(BaseLoader):
         try:
             import pandas as pd
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 """`pandas` package not found.
                 please install with `pip install pandas`
                 """

@@ -15,7 +15,7 @@ def _dependable_tweepy_import() -> tweepy:
     try:
         import tweepy
     except ImportError:
-        raise ValueError(
+        raise ImportError(
             "tweepy package not found, please install it with `pip install tweepy`"
         )
     return tweepy

@@ -30,7 +30,7 @@ class PlaywrightURLLoader(BaseLoader):
         try:
             import playwright  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "playwright package not found, please install it with "
                 "`pip install playwright`"
             )

@@ -40,7 +40,7 @@ class SeleniumURLLoader(BaseLoader):
         try:
             import selenium  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "selenium package not found, please install it with "
                 "`pip install selenium`"
             )
@@ -48,7 +48,7 @@ class SeleniumURLLoader(BaseLoader):
         try:
             import unstructured  # noqa:F401
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "unstructured package not found, please install it with "
                 "`pip install unstructured`"
             )
@@ -4,7 +4,7 @@ from typing import Any, Callable, List, Sequence
 import numpy as np
 from pydantic import BaseModel, Field

-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.math_utils import cosine_similarity
 from langchain.schema import BaseDocumentTransformer, Document

@@ -50,12 +50,14 @@ def _filter_similar_embeddings(


 def _get_embeddings_from_stateful_docs(
-    embeddings: Embeddings, documents: Sequence[_DocumentWithState]
+    embeddings: TextEmbeddingModel, documents: Sequence[_DocumentWithState]
 ) -> List[List[float]]:
     if len(documents) and "embedded_doc" in documents[0].state:
         embedded_documents = [doc.state["embedded_doc"] for doc in documents]
     else:
-        embedded_documents = embeddings.embed_texts([d.page_content for d in documents])
+        embedded_documents = embeddings.embed_documents(
+            [d.page_content for d in documents]
+        )
     for doc, embedding in zip(documents, embedded_documents):
         doc.state["embedded_doc"] = embedding
     return embedded_documents
@@ -64,7 +66,7 @@ def _get_embeddings_from_stateful_docs(
 class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
     """Filter that drops redundant documents by comparing their embeddings."""

-    embeddings: Embeddings
+    embeddings: TextEmbeddingModel
     """Embeddings to use for embedding document contents."""
     similarity_fn: Callable = cosine_similarity
     """Similarity function for comparing documents. Function expected to take as input
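After the rename, the filter is used as before through the document-transformer interface. A short sketch (assuming an OpenAI embedding model; `EmbeddingsRedundantFilter` is exported from `langchain.document_transformers`):

```python
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document

redundant_filter = EmbeddingsRedundantFilter(embeddings=OpenAIEmbeddings())
docs = [
    Document(page_content="How do I install the package?"),
    Document(page_content="How do I install the package?!"),  # near-duplicate
]
# Near-duplicates are dropped based on embedding similarity.
unique_docs = redundant_filter.transform_documents(docs)
```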
@@ -2,11 +2,11 @@ from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, root_validator

-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.utils import get_from_dict_or_env


-class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
+class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, TextEmbeddingModel):
     """
     Wrapper for Aleph Alpha's Asymmetric Embeddings
     AA provides you with an endpoint to embed a document and a query.
@@ -65,7 +65,7 @@ class AlephAlphaAsymmetricSemanticEmbedding(BaseModel, Embeddings):
         values["client"] = Client(token=aleph_alpha_api_key)
         return values

-    def embed_texts(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Call out to Aleph Alpha's asymmetric Document endpoint.

         Args:
@@ -186,7 +186,7 @@ class AlephAlphaSymmetricSemanticEmbedding(AlephAlphaAsymmetricSemanticEmbedding

         return query_response.embedding

-    def embed_texts(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Call out to Aleph Alpha's Document endpoint.

         Args:

@@ -3,17 +3,17 @@ from abc import ABC, abstractmethod
 from typing import List


-class Embeddings(ABC):
-    """Interface for embedding models."""
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """DEPRECATED. Kept for backwards compatibility."""
-        return self.embed_texts(texts)
+class TextEmbeddingModel(ABC):
+    """Interface for text embedding models."""

     @abstractmethod
-    def embed_texts(self, texts: List[str]) -> List[List[float]]:
-        """Embed search texts."""
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed search docs."""

     @abstractmethod
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
+
+
+# For backwards compatibility.
+Embedding = TextEmbeddingModel
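Under the renamed interface, a custom embedding model subclasses `TextEmbeddingModel` and implements the two abstract methods. A minimal, purely illustrative sketch (the toy vectors carry no semantic meaning):

```python
from typing import List

from langchain.embeddings.base import TextEmbeddingModel


class CharStatsEmbeddings(TextEmbeddingModel):
    """Toy embedding model: embeds text as [length, vowel count]."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text)), float(sum(c in "aeiou" for c in text.lower()))]
```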
@@ -3,11 +3,11 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.base import TextEmbeddingModel
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class CohereEmbeddings(BaseModel, Embeddings):
|
||||
class CohereEmbeddings(BaseModel, TextEmbeddingModel):
|
||||
"""Wrapper around Cohere embedding models.
|
||||
|
||||
To use, you should have the ``cohere`` python package installed, and the
|
||||
@@ -48,13 +48,13 @@ class CohereEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
values["client"] = cohere.Client(cohere_api_key)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"Could not import cohere python package. "
|
||||
"Please install it with `pip install cohere`."
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Cohere's embedding endpoint.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,16 +3,16 @@ from typing import List
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.base import TextEmbeddingModel
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings, BaseModel):
|
||||
class FakeEmbeddings(TextEmbeddingModel, BaseModel):
|
||||
size: int
|
||||
|
||||
def _get_embedding(self) -> List[float]:
|
||||
return list(np.random.normal(size=self.size))
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return [self._get_embedding() for _ in texts]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
|
||||
@@ -13,7 +13,7 @@ from tenacity import (
    wait_exponential,
)

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)
@@ -54,7 +54,7 @@ def embed_with_retry(
    return _embed_with_retry(*args, **kwargs)


class GooglePalmEmbeddings(BaseModel, Embeddings):
class GooglePalmEmbeddings(BaseModel, TextEmbeddingModel):
    client: Any
    google_api_key: Optional[str]
    model_name: str = "models/embedding-gecko-001"
@@ -77,7 +77,7 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):

        return values

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, Field

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -13,7 +13,7 @@ DEFAULT_QUERY_INSTRUCTION = (
)


class HuggingFaceEmbeddings(BaseModel, Embeddings):
class HuggingFaceEmbeddings(BaseModel, TextEmbeddingModel):
    """Wrapper around sentence_transformers embedding models.

    To use, you should have the ``sentence_transformers`` python package installed.
@@ -46,7 +46,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
            import sentence_transformers

        except ImportError as exc:
            raise ValueError(
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence_transformers`."
            ) from exc
@@ -60,7 +60,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):

        extra = Extra.forbid

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
@@ -87,7 +87,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
        return embedding.tolist()


class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
class HuggingFaceInstructEmbeddings(BaseModel, TextEmbeddingModel):
    """Wrapper around sentence_transformers embedding models.

    To use, you should have the ``sentence_transformers``
@@ -135,7 +135,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):

        extra = Extra.forbid

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace instruct model.

        Args:
@@ -3,14 +3,14 @@ from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.utils import get_from_dict_or_env

DEFAULT_REPO_ID = "sentence-transformers/all-mpnet-base-v2"
VALID_TASKS = ("feature-extraction",)


class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
class HuggingFaceHubEmbeddings(BaseModel, TextEmbeddingModel):
    """Wrapper around HuggingFaceHub embedding models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
@@ -77,7 +77,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
        )
        return values

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to HuggingFaceHub's embedding endpoint for embedding search docs.

        Args:
@@ -101,5 +101,5 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
        Returns:
            Embeddings for the text.
        """
        response = self.embed_texts([text])[0]
        response = self.embed_documents([text])[0]
        return response
@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional

import requests
from pydantic import BaseModel, root_validator

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.utils import get_from_dict_or_env


class JinaEmbeddings(BaseModel, Embeddings):
class JinaEmbeddings(BaseModel, TextEmbeddingModel):
    client: Any  #: :meta private:

    model_name: str = "ViT-B-32::openai"
@@ -34,7 +34,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
        try:
            import jina
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import `jina` python package. "
                "Please install it with `pip install jina`."
            )
@@ -71,7 +71,7 @@ class JinaEmbeddings(BaseModel, Embeddings):
        payload = dict(inputs=docs, metadata=self.request_headers, **kwargs)
        return self.client.post(on="/encode", **payload)

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Jina's embedding endpoint.
        Args:
            texts: The list of texts to embed.
@@ -3,10 +3,10 @@ from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel


class LlamaCppEmbeddings(BaseModel, Embeddings):
class LlamaCppEmbeddings(BaseModel, TextEmbeddingModel):
    """Wrapper around llama.cpp embedding models.

    To use, you should have the llama-cpp-python library installed, and provide the
@@ -99,7 +99,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):

        return values

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using the Llama model.

        Args:
@@ -25,7 +25,7 @@ from tenacity import (
    wait_exponential,
)

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)
@@ -64,7 +64,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
    return _embed_with_retry(**kwargs)


class OpenAIEmbeddings(BaseModel, Embeddings):
class OpenAIEmbeddings(BaseModel, TextEmbeddingModel):
    """Wrapper around OpenAI embedding models.

    To use, you should have the ``openai`` python package installed, and the
@@ -178,7 +178,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
            openai.api_type = openai_api_type
            values["client"] = openai.Embedding
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
@@ -192,67 +192,64 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
        embeddings: List[List[float]] = [[] for _ in range(len(texts))]
        try:
            import tiktoken

            tokens = []
            indices = []
            encoding = tiktoken.model.encoding_for_model(self.model)
            for i, text in enumerate(texts):
                if self.model.endswith("001"):
                    # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
                    # replace newlines, which can negatively affect performance.
                    text = text.replace("\n", " ")
                token = encoding.encode(
                    text,
                    allowed_special=self.allowed_special,
                    disallowed_special=self.disallowed_special,
                )
                for j in range(0, len(token), self.embedding_ctx_length):
                    tokens += [token[j : j + self.embedding_ctx_length]]
                    indices += [i]

            batched_embeddings = []
            _chunk_size = chunk_size or self.chunk_size
            for i in range(0, len(tokens), _chunk_size):
                response = embed_with_retry(
                    self,
                    input=tokens[i : i + _chunk_size],
                    engine=self.deployment,
                    request_timeout=self.request_timeout,
                    headers=self.headers,
                )
                batched_embeddings += [r["embedding"] for r in response["data"]]

            results: List[List[List[float]]] = [[] for _ in range(len(texts))]
            num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
            for i in range(len(indices)):
                results[indices[i]].append(batched_embeddings[i])
                num_tokens_in_batch[indices[i]].append(len(tokens[i]))

            for i in range(len(texts)):
                _result = results[i]
                if len(_result) == 0:
                    average = embed_with_retry(
                        self,
                        input="",
                        engine=self.deployment,
                        request_timeout=self.request_timeout,
                        headers=self.headers,
                    )["data"][0]["embedding"]
                else:
                    average = np.average(
                        _result, axis=0, weights=num_tokens_in_batch[i]
                    )
                embeddings[i] = (average / np.linalg.norm(average)).tolist()

            return embeddings

        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to use OpenAIEmbeddings. "
                "Please install it with `pip install tiktoken`."
            )

        tokens = []
        indices = []
        encoding = tiktoken.model.encoding_for_model(self.model)
        for i, text in enumerate(texts):
            if self.model.endswith("001"):
                # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
                # replace newlines, which can negatively affect performance.
                text = text.replace("\n", " ")
            token = encoding.encode(
                text,
                allowed_special=self.allowed_special,
                disallowed_special=self.disallowed_special,
            )
            for j in range(0, len(token), self.embedding_ctx_length):
                tokens += [token[j : j + self.embedding_ctx_length]]
                indices += [i]

        batched_embeddings = []
        _chunk_size = chunk_size or self.chunk_size
        for i in range(0, len(tokens), _chunk_size):
            response = embed_with_retry(
                self,
                input=tokens[i : i + _chunk_size],
                engine=self.deployment,
                request_timeout=self.request_timeout,
                headers=self.headers,
            )
            batched_embeddings += [r["embedding"] for r in response["data"]]

        results: List[List[List[float]]] = [[] for _ in range(len(texts))]
        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
        for i in range(len(indices)):
            results[indices[i]].append(batched_embeddings[i])
            num_tokens_in_batch[indices[i]].append(len(tokens[i]))

        for i in range(len(texts)):
            _result = results[i]
            if len(_result) == 0:
                average = embed_with_retry(
                    self,
                    input="",
                    engine=self.deployment,
                    request_timeout=self.request_timeout,
                    headers=self.headers,
                )["data"][0]["embedding"]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()

        return embeddings

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint."""
        # handle large input text
@@ -271,7 +268,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
            headers=self.headers,
        )["data"][0]["embedding"]

    def embed_texts(
    def embed_documents(
        self, texts: List[str], chunk_size: Optional[int] = 0
    ) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.
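The long-input path above splits each text into context-window-sized token chunks, embeds every chunk, then combines the chunk vectors with a token-count-weighted average and L2-normalizes the result. A standalone numpy sketch of just the combination step, with hypothetical chunk vectors:

import numpy as np

# Hypothetical per-chunk embeddings for one long document, plus the number of
# tokens each chunk contained.
chunk_vecs = np.array([[0.2, 0.4], [0.6, 0.0], [0.1, 0.3]])
chunk_tokens = [512, 512, 128]

# Token-weighted average, then L2-normalization, mirroring the loop above.
avg = np.average(chunk_vecs, axis=0, weights=chunk_tokens)
doc_vec = (avg / np.linalg.norm(avg)).tolist()
print(doc_vec)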
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.llms.sagemaker_endpoint import ContentHandlerBase


@@ -11,7 +11,7 @@ class EmbeddingsContentHandler(ContentHandlerBase[List[str], List[List[float]]])
    """Content handler for LLM class."""


class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
class SagemakerEndpointEmbeddings(BaseModel, TextEmbeddingModel):
    """Wrapper around custom Sagemaker Inference Endpoints.

    To use, you must supply the endpoint name from your deployed
@@ -164,7 +164,9 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):

        return self.content_handler.transform_output(response["Body"])

    def embed_texts(self, texts: List[str], chunk_size: int = 64) -> List[List[float]]:
    def embed_documents(
        self, texts: List[str], chunk_size: int = 64
    ) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
@@ -3,7 +3,7 @@ from typing import Any, Callable, List

from pydantic import Extra

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.llms import SelfHostedPipeline


@@ -16,7 +16,7 @@ def _embed_documents(pipeline: Any, *args: Any, **kwargs: Any) -> List[List[floa
    return pipeline(*args, **kwargs)


class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):
class SelfHostedEmbeddings(SelfHostedPipeline, TextEmbeddingModel):
    """Runs custom embedding models on self-hosted remote hardware.

    Supported hardware includes auto-launched instances on AWS, GCP, Azure,
@@ -72,7 +72,7 @@ class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings):

        extra = Extra.forbid

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
@@ -140,7 +140,7 @@ class SelfHostedHuggingFaceInstructEmbeddings(SelfHostedHuggingFaceEmbeddings):
        load_fn_kwargs["device"] = load_fn_kwargs.get("device", 0)
        super().__init__(load_fn_kwargs=load_fn_kwargs, **kwargs)

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace instruct model.

        Args:
@@ -3,12 +3,12 @@ from typing import Any, List

from pydantic import BaseModel, Extra

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel

DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"


class TensorflowHubEmbeddings(BaseModel, Embeddings):
class TensorflowHubEmbeddings(BaseModel, TextEmbeddingModel):
    """Wrapper around tensorflow_hub embedding models.

    To use, you should have the ``tensorflow_text`` python package installed.
@@ -30,20 +30,27 @@ class TensorflowHubEmbeddings(BaseModel, Embeddings):
        super().__init__(**kwargs)
        try:
            import tensorflow_hub
        except ImportError:
            raise ImportError(
                "Could not import tensorflow-hub python package. "
                "Please install it with `pip install tensorflow-hub`."
            )
        try:
            import tensorflow_text  # noqa
        except ImportError:
            raise ImportError(
                "Could not import tensorflow_text python package. "
                "Please install it with `pip install tensorflow_text`."
            )

        self.embed = tensorflow_hub.load(self.model_url)
        except ImportError as e:
            raise ValueError(
                "Could not import some python packages." "Please install them."
            ) from e
        self.embed = tensorflow_hub.load(self.model_url)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a TensorflowHub embedding model.

        Args:
@@ -54,7 +54,7 @@ class NetworkxEntityGraph:
        try:
            import networkx as nx
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import networkx python package. "
                "Please install it with `pip install networkx`."
            )
@@ -70,7 +70,7 @@ class NetworkxEntityGraph:
        try:
            import networkx as nx
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import networkx python package. "
                "Please install it with `pip install networkx`."
            )
@@ -6,7 +6,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.schema import Document
@@ -55,7 +55,7 @@ class VectorstoreIndexCreator(BaseModel):
    """Logic for creating indexes."""

    vectorstore_cls: Type[VectorStore] = Chroma
    embedding: Embeddings = Field(default_factory=OpenAIEmbeddings)
    embedding: TextEmbeddingModel = Field(default_factory=OpenAIEmbeddings)
    text_splitter: TextSplitter = Field(default_factory=_get_default_text_splitter)
    vectorstore_kwargs: dict = Field(default_factory=dict)
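With the field retyped, any implementation of the embedding interface can be injected into `VectorstoreIndexCreator`. A minimal sketch, assuming a local sentence-transformers model and a hypothetical source file; the LLM side of `query` still needs an OpenAI key by default:

from langchain.document_loaders import TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.indexes import VectorstoreIndexCreator

creator = VectorstoreIndexCreator(embedding=HuggingFaceEmbeddings())
index = creator.from_loaders([TextLoader("state_of_the_union.txt")])  # hypothetical file
print(index.query("What was said about the economy?"))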
@@ -24,6 +24,7 @@ from langchain.llms.llamacpp import LlamaCpp
from langchain.llms.modal import Modal
from langchain.llms.nlpcloud import NLPCloud
from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
from langchain.llms.openlm import OpenLM
from langchain.llms.petals import Petals
from langchain.llms.pipelineai import PipelineAI
from langchain.llms.predictionguard import PredictionGuard
@@ -53,6 +54,7 @@ __all__ = [
    "NLPCloud",
    "OpenAI",
    "OpenAIChat",
    "OpenLM",
    "Petals",
    "PipelineAI",
    "HuggingFaceEndpoint",
@@ -96,6 +98,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "nlpcloud": NLPCloud,
    "human-input": HumanInputLLM,
    "openai": OpenAI,
    "openlm": OpenLM,
    "petals": Petals,
    "pipelineai": PipelineAI,
    "huggingface_pipeline": HuggingFacePipeline,
@@ -148,7 +148,7 @@ class AlephAlpha(LLM):

            values["client"] = aleph_alpha_client.Client(token=aleph_alpha_api_key)
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import aleph_alpha_client python package. "
                "Please install it with `pip install aleph_alpha_client`."
            )

@@ -59,7 +59,7 @@ class _AnthropicCommon(BaseModel):
            values["AI_PROMPT"] = anthropic.AI_PROMPT
            values["count_tokens"] = anthropic.count_tokens
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import anthropic python package. "
                "Please install it with `pip install anthropic`."
            )

@@ -91,7 +91,7 @@ class Banana(LLM):
        try:
            import banana_dev as banana
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import banana-dev python package. "
                "Please install it with `pip install banana-dev`."
            )

@@ -72,7 +72,7 @@ class Cohere(LLM):

            values["client"] = cohere.Client(cohere_api_key)
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            )

@@ -29,7 +29,10 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
    try:
        import google.api_core.exceptions
    except ImportError:
        raise ImportError()
        raise ImportError(
            "Could not import google-api-core python package. "
            "Please install it with `pip install google-api-core`."
        )

    multiplier = 2
    min_seconds = 1
@@ -105,7 +108,10 @@ class GooglePalm(BaseLLM, BaseModel):

            genai.configure(api_key=google_api_key)
        except ImportError:
            raise ImportError("Could not import google.generativeai python package.")
            raise ImportError(
                "Could not import google-generativeai python package. "
                "Please install it with `pip install google-generativeai`."
            )

        values["client"] = genai

@@ -100,7 +100,7 @@ class GooseAI(LLM):
            openai.api_base = "https://api.goose.ai/v1"
            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )

@@ -97,7 +97,7 @@ class HuggingFaceTextGenInference(LLM):
                values["inference_server_url"], timeout=values["timeout"]
            )
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import text_generation python package. "
                "Please install it with `pip install text_generation`."
            )

@@ -75,7 +75,7 @@ class NLPCloud(LLM):
                values["model_name"], nlpcloud_api_key, gpu=True, lang="en"
            )
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import nlpcloud python package. "
                "Please install it with `pip install nlpcloud`."
            )

@@ -234,7 +234,7 @@ class BaseOpenAI(BaseLLM):
            openai.organization = openai_organization
            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
@@ -462,7 +462,7 @@ class BaseOpenAI(BaseLLM):
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
@@ -677,7 +677,7 @@ class OpenAIChat(BaseLLM):
            if openai_organization:
                openai.organization = openai_organization
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
@@ -807,7 +807,7 @@ class OpenAIChat(BaseLLM):
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
26
langchain/llms/openlm.py
Normal file
@@ -0,0 +1,26 @@
from typing import Any, Dict

from pydantic import root_validator

from langchain.llms.openai import BaseOpenAI


class OpenLM(BaseOpenAI):
    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**{"model": self.model_name}, **super()._invocation_params}

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        try:
            import openlm

            values["client"] = openlm.Completion
        except ImportError:
            raise ValueError(
                "Could not import openlm python package. "
                "Please install it with `pip install openlm`."
            )
        if values["streaming"]:
            raise ValueError("Streaming not supported with openlm")
        return values
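A brief usage sketch for the new `OpenLM` wrapper registered above. The model name is illustrative, and the relevant provider API key is assumed to be set in the environment:

from langchain.llms import OpenLM

# OpenLM routes OpenAI-style completion calls through the `openlm` package.
llm = OpenLM(model_name="text-davinci-003")
print(llm("What is the capital of France?"))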
@@ -50,7 +50,7 @@ class PredictionGuard(LLM):

            values["client"] = pg.Client(token=token)
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import predictionguard python package. "
                "Please install it with `pip install predictionguard`."
            )

@@ -89,7 +89,7 @@ class Replicate(LLM):
        try:
            import replicate as replicate_python
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import replicate python package. "
                "Please install it with `pip install replicate`."
            )

@@ -103,7 +103,7 @@ class RWKV(LLM, BaseModel):
        try:
            import tokenizers
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import tokenizers python package. "
                "Please install it with `pip install tokenizers`."
            )

@@ -182,7 +182,7 @@ class SagemakerEndpoint(LLM):
            ) from e

        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )

@@ -155,7 +155,7 @@ class SelfHostedPipeline(LLM):
            import runhouse as rh

        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import runhouse python package. "
                "Please install it with `pip install runhouse`."
            )
@@ -52,13 +52,11 @@ class FirestoreChatMessageHistory(BaseChatMessageHistory):
        try:
            import firebase_admin
            from firebase_admin import firestore
        except ImportError as e:
            logger.error(
                "Failed to import Firebase and Firestore: %s. "
                "Make sure to install the 'firebase-admin' module.",
                e,
        except ImportError:
            raise ImportError(
                "Could not import firebase-admin python package. "
                "Please install it with `pip install firebase-admin`."
            )
            raise e

        # For multiple instances, only initialize the app once.
        try:

@@ -25,7 +25,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
        try:
            import redis
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )

@@ -91,7 +91,7 @@ class RedisEntityStore(BaseEntityStore):
        try:
            import redis
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )
@@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Type

from pydantic import BaseModel, Extra

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.vectorstores.base import VectorStore

@@ -64,7 +64,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
    def from_examples(
        cls,
        examples: List[dict],
        embeddings: Embeddings,
        embeddings: TextEmbeddingModel,
        vectorstore_cls: Type[VectorStore],
        k: int = 4,
        input_keys: Optional[List[str]] = None,
@@ -130,7 +130,7 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
    def from_examples(
        cls,
        examples: List[dict],
        embeddings: Embeddings,
        embeddings: TextEmbeddingModel,
        vectorstore_cls: Type[VectorStore],
        k: int = 4,
        input_keys: Optional[List[str]] = None,
@@ -41,7 +41,7 @@ class CohereRerank(BaseDocumentCompressor):

            values["client"] = cohere.Client(cohere_api_key)
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            )
@@ -8,7 +8,7 @@ from langchain.document_transformers import (
    _get_embeddings_from_stateful_docs,
    get_stateful_documents,
)
from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.math_utils import cosine_similarity
from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
@@ -17,7 +17,7 @@ from langchain.schema import Document


class EmbeddingsFilter(BaseDocumentCompressor):
    embeddings: Embeddings
    embeddings: TextEmbeddingModel
    """Embeddings to use for embedding document contents and queries."""
    similarity_fn: Callable = cosine_similarity
    """Similarity function for comparing documents. Function expected to take as input
@@ -10,17 +10,17 @@ from typing import Any, List, Optional

import numpy as np
from pydantic import BaseModel

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.schema import BaseRetriever, Document


def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
def create_index(contexts: List[str], embeddings: TextEmbeddingModel) -> np.ndarray:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        return np.array(list(executor.map(embeddings.embed_query, contexts)))


class KNNRetriever(BaseRetriever, BaseModel):
    embeddings: Embeddings
    embeddings: TextEmbeddingModel
    index: Any
    texts: List[str]
    k: int = 4
@@ -34,7 +34,7 @@ class KNNRetriever(BaseRetriever, BaseModel):

    @classmethod
    def from_texts(
        cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
        cls, texts: List[str], embeddings: TextEmbeddingModel, **kwargs: Any
    ) -> KNNRetriever:
        index = create_index(texts, embeddings)
        return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
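A standalone sketch of the `create_index` pattern above: per-text `embed_query` calls are fanned out over a thread pool, which helps because the calls are I/O-bound. `FakeEmbeddings` stands in for a real model here:

import concurrent.futures

import numpy as np

from langchain.embeddings import FakeEmbeddings

emb = FakeEmbeddings(size=8)
contexts = ["first passage", "second passage", "third passage"]
with concurrent.futures.ThreadPoolExecutor() as executor:
    # executor.map preserves input order, so row i corresponds to contexts[i].
    index = np.array(list(executor.map(emb.embed_query, contexts)))
print(index.shape)  # (3, 8)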
@@ -1,7 +1,7 @@
"""Milvus Retriever"""
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.milvus import Milvus

@@ -11,7 +11,7 @@ from langchain.vectorstores.milvus import Milvus
class MilvusRetreiver(BaseRetriever):
    def __init__(
        self,
        embedding_function: Embeddings,
        embedding_function: TextEmbeddingModel,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[Dict[str, Any]] = None,
        consistency_level: str = "Session",
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.schema import BaseRetriever, Document


@@ -15,7 +15,7 @@ def hash_text(text: str) -> str:
def create_index(
    contexts: List[str],
    index: Any,
    embeddings: Embeddings,
    embeddings: TextEmbeddingModel,
    sparse_encoder: Any,
    ids: Optional[List[str]] = None,
    metadatas: Optional[List[dict]] = None,
@@ -49,7 +49,7 @@ def create_index(
        ]

        # create dense vectors
        dense_embeds = embeddings.embed_texts(context_batch)
        dense_embeds = embeddings.embed_documents(context_batch)
        # create sparse vectors
        sparse_embeds = sparse_encoder.encode_documents(context_batch)
        for s in sparse_embeds:
@@ -74,7 +74,7 @@ def create_index(


class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
    embeddings: Embeddings
    embeddings: TextEmbeddingModel
    sparse_encoder: Any
    index: Any
    top_k: int = 4
@@ -10,17 +10,17 @@ from typing import Any, List, Optional

import numpy as np
from pydantic import BaseModel

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.schema import BaseRetriever, Document


def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
def create_index(contexts: List[str], embeddings: TextEmbeddingModel) -> np.ndarray:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        return np.array(list(executor.map(embeddings.embed_query, contexts)))


class SVMRetriever(BaseRetriever, BaseModel):
    embeddings: Embeddings
    embeddings: TextEmbeddingModel
    index: Any
    texts: List[str]
    k: int = 4
@@ -34,7 +34,7 @@ class SVMRetriever(BaseRetriever, BaseModel):

    @classmethod
    def from_texts(
        cls, texts: List[str], embeddings: Embeddings, **kwargs: Any
        cls, texts: List[str], embeddings: TextEmbeddingModel, **kwargs: Any
    ) -> SVMRetriever:
        index = create_index(texts, embeddings)
        return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
@@ -109,7 +109,9 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to vectorstore."""
        current_time = kwargs.get("current_time", datetime.datetime.now())
        current_time = kwargs.get("current_time")
        if current_time is None:
            current_time = datetime.datetime.now()
        # Avoid mutating input documents
        dup_docs = [deepcopy(d) for d in documents]
        for i, doc in enumerate(dup_docs):
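A toy illustration of why the hunk above splits the lookup in two: with `kwargs.get("current_time", datetime.datetime.now())`, a caller who explicitly passes `current_time=None` keeps `None`, and the default is evaluated eagerly on every call whether or not it is used. The two-step form treats an explicit `None` the same as "not set":

import datetime


def stamp(**kwargs):
    current_time = kwargs.get("current_time")
    if current_time is None:
        current_time = datetime.datetime.now()
    return current_time


print(stamp())                   # now
print(stamp(current_time=None))  # also now, instead of leaking None downstream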
@@ -23,7 +23,7 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
        try:
            import weaviate
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            )
@@ -1,7 +1,7 @@
"""Zilliz Retriever"""
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.embeddings.base import TextEmbeddingModel
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.zilliz import Zilliz

@@ -11,7 +11,7 @@ from langchain.vectorstores.zilliz import Zilliz
class ZillizRetreiver(BaseRetriever):
    def __init__(
        self,
        embedding_function: Embeddings,
        embedding_function: TextEmbeddingModel,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[Dict[str, Any]] = None,
        consistency_level: str = "Session",
@@ -64,10 +64,12 @@ class TextSplitter(BaseDocumentTransformer, ABC):
            documents.append(new_doc)
        return documents

    def split_documents(self, documents: List[Document]) -> List[Document]:
    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents."""
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
@@ -154,7 +156,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
@@ -232,7 +234,7 @@ class TokenTextSplitter(TextSplitter):
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to use TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
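A short sketch of what the `Iterable[Document]` signature above enables. The old two-comprehension body iterated `documents` twice, so a generator would be exhausted on the first pass and the metadata list would come back empty; the single loop consumes each document once:

from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter


def load_docs():
    # A generator of documents, never materialized as a list.
    for i in range(3):
        yield Document(page_content="word " * 200, metadata={"source": f"doc-{i}"})


splitter = CharacterTextSplitter(separator=" ", chunk_size=100, chunk_overlap=0)
chunks = splitter.split_documents(load_docs())
print(len(chunks), chunks[0].metadata)  # metadata survives the single pass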
@@ -1,5 +1,11 @@
"""Core toolkit implementations."""

from langchain.tools.azure_cognitive_services import (
    AzureCogsFormRecognizerTool,
    AzureCogsImageAnalysisTool,
    AzureCogsSpeech2TextTool,
    AzureCogsText2SpeechTool,
)
from langchain.tools.base import BaseTool, StructuredTool, Tool, tool
from langchain.tools.bing_search.tool import BingSearchResults, BingSearchRun
from langchain.tools.ddg_search.tool import DuckDuckGoSearchResults, DuckDuckGoSearchRun
@@ -56,6 +62,10 @@ from langchain.tools.zapier.tool import ZapierNLAListActions, ZapierNLARunAction
__all__ = [
    "AIPluginTool",
    "APIOperation",
    "AzureCogsFormRecognizerTool",
    "AzureCogsImageAnalysisTool",
    "AzureCogsSpeech2TextTool",
    "AzureCogsText2SpeechTool",
    "BaseTool",
    "BaseTool",
    "BaseTool",
21
langchain/tools/azure_cognitive_services/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""Azure Cognitive Services Tools."""

from langchain.tools.azure_cognitive_services.form_recognizer import (
    AzureCogsFormRecognizerTool,
)
from langchain.tools.azure_cognitive_services.image_analysis import (
    AzureCogsImageAnalysisTool,
)
from langchain.tools.azure_cognitive_services.speech2text import (
    AzureCogsSpeech2TextTool,
)
from langchain.tools.azure_cognitive_services.text2speech import (
    AzureCogsText2SpeechTool,
)

__all__ = [
    "AzureCogsImageAnalysisTool",
    "AzureCogsFormRecognizerTool",
    "AzureCogsSpeech2TextTool",
    "AzureCogsText2SpeechTool",
]
152
langchain/tools/azure_cognitive_services/form_recognizer.py
Normal file
@@ -0,0 +1,152 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from pydantic import root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.azure_cognitive_services.utils import detect_file_src_type
from langchain.tools.base import BaseTool
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class AzureCogsFormRecognizerTool(BaseTool):
    """Tool that queries the Azure Cognitive Services Form Recognizer API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/applied-ai-services/form-recognizer/quickstarts/get-started-sdks-rest-api?view=form-recog-3.0.0&pivots=programming-language-python
    """

    azure_cogs_key: str = ""  #: :meta private:
    azure_cogs_endpoint: str = ""  #: :meta private:
    doc_analysis_client: Any  #: :meta private:

    name = "Azure Cognitive Services Form Recognizer"
    description = (
        "A wrapper around Azure Cognitive Services Form Recognizer. "
        "Useful for when you need to "
        "extract text, tables, and key-value pairs from documents. "
        "Input should be a url to a document."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment."""
        azure_cogs_key = get_from_dict_or_env(
            values, "azure_cogs_key", "AZURE_COGS_KEY"
        )

        azure_cogs_endpoint = get_from_dict_or_env(
            values, "azure_cogs_endpoint", "AZURE_COGS_ENDPOINT"
        )

        try:
            from azure.ai.formrecognizer import DocumentAnalysisClient
            from azure.core.credentials import AzureKeyCredential

            values["doc_analysis_client"] = DocumentAnalysisClient(
                endpoint=azure_cogs_endpoint,
                credential=AzureKeyCredential(azure_cogs_key),
            )

        except ImportError:
            raise ImportError(
                "azure-ai-formrecognizer is not installed. "
                "Run `pip install azure-ai-formrecognizer` to install."
            )

        return values

    def _parse_tables(self, tables: List[Any]) -> List[Any]:
        result = []
        for table in tables:
            rc, cc = table.row_count, table.column_count
            _table = [["" for _ in range(cc)] for _ in range(rc)]
            for cell in table.cells:
                _table[cell.row_index][cell.column_index] = cell.content
            result.append(_table)
        return result

    def _parse_kv_pairs(self, kv_pairs: List[Any]) -> List[Any]:
        result = []
        for kv_pair in kv_pairs:
            key = kv_pair.key.content if kv_pair.key else ""
            value = kv_pair.value.content if kv_pair.value else ""
            result.append((key, value))
        return result

    def _document_analysis(self, document_path: str) -> Dict:
        document_src_type = detect_file_src_type(document_path)
        if document_src_type == "local":
            with open(document_path, "rb") as document:
                poller = self.doc_analysis_client.begin_analyze_document(
                    "prebuilt-document", document
                )
        elif document_src_type == "remote":
            poller = self.doc_analysis_client.begin_analyze_document_from_url(
                "prebuilt-document", document_path
            )
        else:
            raise ValueError(f"Invalid document path: {document_path}")

        result = poller.result()
        res_dict = {}

        if result.content is not None:
            res_dict["content"] = result.content

        if result.tables is not None:
            res_dict["tables"] = self._parse_tables(result.tables)

        if result.key_value_pairs is not None:
            res_dict["key_value_pairs"] = self._parse_kv_pairs(result.key_value_pairs)

        return res_dict

    def _format_document_analysis_result(self, document_analysis_result: Dict) -> str:
        formatted_result = []
        if "content" in document_analysis_result:
            formatted_result.append(
                f"Content: {document_analysis_result['content']}".replace("\n", " ")
            )

        if "tables" in document_analysis_result:
            for i, table in enumerate(document_analysis_result["tables"]):
                formatted_result.append(f"Table {i}: {table}".replace("\n", " "))

        if "key_value_pairs" in document_analysis_result:
            for kv_pair in document_analysis_result["key_value_pairs"]:
                formatted_result.append(
                    f"{kv_pair[0]}: {kv_pair[1]}".replace("\n", " ")
                )

        return "\n".join(formatted_result)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            document_analysis_result = self._document_analysis(query)
            if not document_analysis_result:
                return "No good document analysis result was found"

            return self._format_document_analysis_result(document_analysis_result)
        except Exception as e:
            raise RuntimeError(f"Error while running AzureCogsFormRecognizerTool: {e}")

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("AzureCogsFormRecognizerTool does not support async")
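A minimal sketch of calling the new tool directly. The document URL is hypothetical, and AZURE_COGS_KEY / AZURE_COGS_ENDPOINT must be set in the environment for validate_environment to pass:

from langchain.tools.azure_cognitive_services import AzureCogsFormRecognizerTool

form_tool = AzureCogsFormRecognizerTool()
# Returns extracted content, tables, and key-value pairs as formatted text.
print(form_tool.run("https://example.com/sample-invoice.pdf"))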
156
langchain/tools/azure_cognitive_services/image_analysis.py
Normal file
@@ -0,0 +1,156 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Optional

from pydantic import root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.azure_cognitive_services.utils import detect_file_src_type
from langchain.tools.base import BaseTool
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class AzureCogsImageAnalysisTool(BaseTool):
    """Tool that queries the Azure Cognitive Services Image Analysis API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/cognitive-services/computer-vision/quickstarts-sdk/image-analysis-client-library-40
    """

    azure_cogs_key: str = ""  #: :meta private:
    azure_cogs_endpoint: str = ""  #: :meta private:
    vision_service: Any  #: :meta private:
    analysis_options: Any  #: :meta private:

    name = "Azure Cognitive Services Image Analysis"
    description = (
        "A wrapper around Azure Cognitive Services Image Analysis. "
        "Useful for when you need to analyze images. "
        "Input should be a url to an image."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment."""
        azure_cogs_key = get_from_dict_or_env(
            values, "azure_cogs_key", "AZURE_COGS_KEY"
        )

        azure_cogs_endpoint = get_from_dict_or_env(
            values, "azure_cogs_endpoint", "AZURE_COGS_ENDPOINT"
        )

        try:
            import azure.ai.vision as sdk

            values["vision_service"] = sdk.VisionServiceOptions(
                endpoint=azure_cogs_endpoint, key=azure_cogs_key
            )

            values["analysis_options"] = sdk.ImageAnalysisOptions()
            values["analysis_options"].features = (
                sdk.ImageAnalysisFeature.CAPTION
                | sdk.ImageAnalysisFeature.OBJECTS
                | sdk.ImageAnalysisFeature.TAGS
                | sdk.ImageAnalysisFeature.TEXT
            )
        except ImportError:
            raise ImportError(
                "azure-ai-vision is not installed. "
                "Run `pip install azure-ai-vision` to install."
            )

        return values

    def _image_analysis(self, image_path: str) -> Dict:
        try:
            import azure.ai.vision as sdk
        except ImportError:
            pass

        image_src_type = detect_file_src_type(image_path)
        if image_src_type == "local":
            vision_source = sdk.VisionSource(filename=image_path)
        elif image_src_type == "remote":
            vision_source = sdk.VisionSource(url=image_path)
        else:
            raise ValueError(f"Invalid image path: {image_path}")

        image_analyzer = sdk.ImageAnalyzer(
            self.vision_service, vision_source, self.analysis_options
        )
        result = image_analyzer.analyze()

        res_dict = {}
        if result.reason == sdk.ImageAnalysisResultReason.ANALYZED:
            if result.caption is not None:
                res_dict["caption"] = result.caption.content

            if result.objects is not None:
                res_dict["objects"] = [obj.name for obj in result.objects]

            if result.tags is not None:
                res_dict["tags"] = [tag.name for tag in result.tags]

            if result.text is not None:
                res_dict["text"] = [line.content for line in result.text.lines]

        else:
            error_details = sdk.ImageAnalysisErrorDetails.from_result(result)
            raise RuntimeError(
                f"Image analysis failed.\n"
                f"Reason: {error_details.reason}\n"
                f"Details: {error_details.message}"
            )

        return res_dict

    def _format_image_analysis_result(self, image_analysis_result: Dict) -> str:
        formatted_result = []
        if "caption" in image_analysis_result:
            formatted_result.append("Caption: " + image_analysis_result["caption"])

        if (
            "objects" in image_analysis_result
            and len(image_analysis_result["objects"]) > 0
        ):
            formatted_result.append(
                "Objects: " + ", ".join(image_analysis_result["objects"])
            )

        if "tags" in image_analysis_result and len(image_analysis_result["tags"]) > 0:
            formatted_result.append("Tags: " + ", ".join(image_analysis_result["tags"]))

        if "text" in image_analysis_result and len(image_analysis_result["text"]) > 0:
            formatted_result.append("Text: " + ", ".join(image_analysis_result["text"]))

        return "\n".join(formatted_result)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            image_analysis_result = self._image_analysis(query)
            if not image_analysis_result:
                return "No good image analysis result was found"

            return self._format_image_analysis_result(image_analysis_result)
        except Exception as e:
            raise RuntimeError(f"Error while running AzureCogsImageAnalysisTool: {e}")

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("AzureCogsImageAnalysisTool does not support async")
131
langchain/tools/azure_cognitive_services/speech2text.py
Normal file
@@ -0,0 +1,131 @@
from __future__ import annotations

import logging
import time
from typing import Any, Dict, Optional

from pydantic import root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.azure_cognitive_services.utils import (
    detect_file_src_type,
    download_audio_from_url,
)
from langchain.tools.base import BaseTool
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class AzureCogsSpeech2TextTool(BaseTool):
    """Tool that queries the Azure Cognitive Services Speech2Text API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/get-started-speech-to-text?pivots=programming-language-python
    """

    azure_cogs_key: str = ""  #: :meta private:
    azure_cogs_region: str = ""  #: :meta private:
    speech_language: str = "en-US"  #: :meta private:
    speech_config: Any  #: :meta private:

    name = "Azure Cognitive Services Speech2Text"
    description = (
        "A wrapper around Azure Cognitive Services Speech2Text. "
        "Useful for when you need to transcribe audio to text. "
        "Input should be a url to an audio file."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment."""
        azure_cogs_key = get_from_dict_or_env(
            values, "azure_cogs_key", "AZURE_COGS_KEY"
        )

        azure_cogs_region = get_from_dict_or_env(
            values, "azure_cogs_region", "AZURE_COGS_REGION"
        )

        try:
            import azure.cognitiveservices.speech as speechsdk

            values["speech_config"] = speechsdk.SpeechConfig(
                subscription=azure_cogs_key, region=azure_cogs_region
            )
        except ImportError:
            raise ImportError(
                "azure-cognitiveservices-speech is not installed. "
                "Run `pip install azure-cognitiveservices-speech` to install."
            )

        return values

    def _continuous_recognize(self, speech_recognizer: Any) -> str:
        done = False
        text = ""

        def stop_cb(evt: Any) -> None:
            """callback that stop continuous recognition"""
            speech_recognizer.stop_continuous_recognition_async()
            nonlocal done
            done = True

        def retrieve_cb(evt: Any) -> None:
            """callback that retrieves the intermediate recognition results"""
            nonlocal text
            text += evt.result.text

        # retrieve text on recognized events
        speech_recognizer.recognized.connect(retrieve_cb)
        # stop continuous recognition on either session stopped or canceled events
        speech_recognizer.session_stopped.connect(stop_cb)
        speech_recognizer.canceled.connect(stop_cb)

        # Start continuous speech recognition
        speech_recognizer.start_continuous_recognition_async()
        while not done:
            time.sleep(0.5)
        return text

    def _speech2text(self, audio_path: str, speech_language: str) -> str:
        try:
            import azure.cognitiveservices.speech as speechsdk
        except ImportError:
            pass

        audio_src_type = detect_file_src_type(audio_path)
        if audio_src_type == "local":
            audio_config = speechsdk.AudioConfig(filename=audio_path)
        elif audio_src_type == "remote":
            tmp_audio_path = download_audio_from_url(audio_path)
            audio_config = speechsdk.AudioConfig(filename=tmp_audio_path)
        else:
            raise ValueError(f"Invalid audio path: {audio_path}")

        self.speech_config.speech_recognition_language = speech_language
        speech_recognizer = speechsdk.SpeechRecognizer(self.speech_config, audio_config)
        return self._continuous_recognize(speech_recognizer)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            text = self._speech2text(query, self.speech_language)
            return text
        except Exception as e:
            raise RuntimeError(f"Error while running AzureCogsSpeech2TextTool: {e}")

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("AzureCogsSpeech2TextTool does not support async")
114
langchain/tools/azure_cognitive_services/text2speech.py
Normal file
114
langchain/tools/azure_cognitive_services/text2speech.py
Normal file
@@ -0,0 +1,114 @@
from __future__ import annotations

import logging
import tempfile
from typing import Any, Dict, Optional

from pydantic import root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class AzureCogsText2SpeechTool(BaseTool):
    """Tool that queries the Azure Cognitive Services Text2Speech API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/get-started-text-to-speech?pivots=programming-language-python
    """

    azure_cogs_key: str = ""  #: :meta private:
    azure_cogs_region: str = ""  #: :meta private:
    speech_language: str = "en-US"  #: :meta private:
    speech_config: Any  #: :meta private:

    name = "Azure Cognitive Services Text2Speech"
    description = (
        "A wrapper around Azure Cognitive Services Text2Speech. "
        "Useful for when you need to convert text to speech. "
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exist in environment."""
        azure_cogs_key = get_from_dict_or_env(
            values, "azure_cogs_key", "AZURE_COGS_KEY"
        )

        azure_cogs_region = get_from_dict_or_env(
            values, "azure_cogs_region", "AZURE_COGS_REGION"
        )

        try:
            import azure.cognitiveservices.speech as speechsdk

            values["speech_config"] = speechsdk.SpeechConfig(
                subscription=azure_cogs_key, region=azure_cogs_region
            )
        except ImportError:
            raise ImportError(
                "azure-cognitiveservices-speech is not installed. "
                "Run `pip install azure-cognitiveservices-speech` to install."
            )

        return values

    def _text2speech(self, text: str, speech_language: str) -> str:
        try:
            import azure.cognitiveservices.speech as speechsdk
        except ImportError:
            # silently passing here would raise a NameError on the first
            # speechsdk reference below, so fail loudly instead
            raise ImportError(
                "azure-cognitiveservices-speech is not installed. "
                "Run `pip install azure-cognitiveservices-speech` to install."
            )

        self.speech_config.speech_synthesis_language = speech_language
        speech_synthesizer = speechsdk.SpeechSynthesizer(
            speech_config=self.speech_config, audio_config=None
        )
        result = speech_synthesizer.speak_text(text)

        if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
            stream = speechsdk.AudioDataStream(result)
            with tempfile.NamedTemporaryFile(
                mode="wb", suffix=".wav", delete=False
            ) as f:
                stream.save_to_wav_file(f.name)

            return f.name

        elif result.reason == speechsdk.ResultReason.Canceled:
            cancellation_details = result.cancellation_details
            logger.debug(f"Speech synthesis canceled: {cancellation_details.reason}")
            if cancellation_details.reason == speechsdk.CancellationReason.Error:
                raise RuntimeError(
                    f"Speech synthesis error: {cancellation_details.error_details}"
                )

            return "Speech synthesis canceled."

        else:
            return f"Speech synthesis failed: {result.reason}"

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            speech_file = self._text2speech(query, self.speech_language)
            return speech_file
        except Exception as e:
            raise RuntimeError(f"Error while running AzureCogsText2SpeechTool: {e}")

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("AzureCogsText2SpeechTool does not support async")
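A matching usage sketch for the text-to-speech tool, under the same assumptions (environment variables set; import path inferred from this diff's layout; not part of the diff itself):

    # Hypothetical usage of AzureCogsText2SpeechTool.
    from langchain.tools.azure_cognitive_services.text2speech import (
        AzureCogsText2SpeechTool,
    )

    tool = AzureCogsText2SpeechTool()
    wav_path = tool.run("Hello, World!")  # path to a temporary .wav file
    print(wav_path)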
29 langchain/tools/azure_cognitive_services/utils.py Normal file
@@ -0,0 +1,29 @@
import os
import tempfile
from urllib.parse import urlparse

import requests


def detect_file_src_type(file_path: str) -> str:
    """Detect if the file is local or remote."""
    if os.path.isfile(file_path):
        return "local"

    parsed_url = urlparse(file_path)
    if parsed_url.scheme and parsed_url.netloc:
        return "remote"

    return "invalid"


def download_audio_from_url(audio_url: str) -> str:
    """Download audio from url to local."""
    ext = audio_url.split(".")[-1]
    response = requests.get(audio_url, stream=True)
    response.raise_for_status()
    with tempfile.NamedTemporaryFile(mode="wb", suffix=f".{ext}", delete=False) as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

    return f.name
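A quick sketch of how these two helpers behave (the URL is illustrative only; not part of the diff):

    # detect_file_src_type: existing file -> "local"; scheme + netloc -> "remote";
    # anything else -> "invalid". download_audio_from_url streams the response
    # into a NamedTemporaryFile and returns its path.
    from langchain.tools.azure_cognitive_services.utils import (
        detect_file_src_type,
        download_audio_from_url,
    )

    assert detect_file_src_type("https://example.com/clip.wav") == "remote"
    assert detect_file_src_type("no/such/file.wav") == "invalid"

    local_path = download_audio_from_url("https://example.com/clip.wav")  # illustrative URL
    print(local_path)  # e.g. /tmp/tmpabc123.wav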
@@ -35,7 +35,7 @@ class LambdaWrapper(BaseModel):

         except ImportError:
             raise ImportError(
-                "boto3 is not installed." "Please install it with `pip install boto3`"
+                "boto3 is not installed. Please install it with `pip install boto3`"
             )

         values["lambda_client"] = boto3.client("lambda")
@@ -51,8 +51,8 @@ class GooglePlacesAPIWrapper(BaseModel):

             values["google_map_client"] = googlemaps.Client(gplaces_api_key)
         except ImportError:
-            raise ValueError(
-                "Could not import googlemaps python packge. "
+            raise ImportError(
+                "Could not import googlemaps python package. "
                 "Please install it with `pip install googlemaps`."
             )
         return values
@@ -156,7 +156,7 @@ class JiraAPIWrapper(BaseModel):
             import json
         except ImportError:
             raise ImportError(
-                "json is not installed. " "Please install it with `pip install json`"
+                "json is not installed. Please install it with `pip install json`"
             )
         params = json.loads(query)
         return self.jira.issue_create(fields=dict(params))
@@ -38,7 +38,7 @@ class OpenWeatherMapAPIWrapper(BaseModel):

         except ImportError:
             raise ImportError(
-                "pyowm is not installed. " "Please install it with `pip install pyowm`"
+                "pyowm is not installed. Please install it with `pip install pyowm`"
             )

         owm = pyowm.OWM(openweathermap_api_key)
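These four hunks all tidy error-message literals, and the boto3 one fixes a real rendering bug: adjacent Python string literals concatenate with no separator, so a missing space inside the first literal runs the sentences together. The json and pyowm messages already carried a trailing space and are merged purely for readability, and the googlemaps hunk additionally fixes the "packge" typo and raises ImportError instead of ValueError. A minimal illustration:

    # Adjacent string literals concatenate verbatim.
    before = "boto3 is not installed." "Please install it with `pip install boto3`"
    after = "boto3 is not installed. Please install it with `pip install boto3`"
    print(before)  # boto3 is not installed.Please install it with `pip install boto3`
    print(after)   # boto3 is not installed. Please install it with `pip install boto3`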
@@ -13,7 +13,7 @@ from sqlalchemy.orm import Session, relationship
 from sqlalchemy.sql.expression import func

 from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.utils import get_from_dict_or_env
 from langchain.vectorstores.base import VectorStore

@@ -126,7 +126,7 @@ class AnalyticDB(VectorStore):
     def __init__(
         self,
         connection_string: str,
-        embedding_function: Embeddings,
+        embedding_function: TextEmbeddingModel,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         collection_metadata: Optional[dict] = None,
         pre_delete_collection: bool = False,
@@ -202,7 +202,7 @@ class AnalyticDB(VectorStore):
         if ids is None:
             ids = [str(uuid.uuid1()) for _ in texts]

-        embeddings = self.embedding_function.embed_texts(list(texts))
+        embeddings = self.embedding_function.embed_documents(list(texts))

         if not metadatas:
             metadatas = [{} for _ in texts]
@@ -343,7 +343,7 @@ class AnalyticDB(VectorStore):
     def from_texts(
         cls,
         texts: List[str],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         metadatas: Optional[List[dict]] = None,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         ids: Optional[List[str]] = None,
@@ -390,7 +390,7 @@ class AnalyticDB(VectorStore):
     def from_documents(
         cls,
         documents: List[Document],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
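The same two-part rename recurs in every remaining file of this diff (Annoy, AtlasDB, the base VectorStore interface, Chroma, DeepLake): the annotation `Embeddings` becomes `TextEmbeddingModel`, and embedding call sites go through `embed_documents`. The `TextEmbeddingModel` class itself is not shown in this excerpt, so the stub below only assumes the one method these hunks exercise:

    # Stand-in for the interface these hunks assume (not the real class).
    # embed_documents embeds a batch of texts; AtlasDB below also uses
    # embed_documents([query])[0] to embed a single query.
    from typing import List

    class StubTextEmbeddingModel:
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            # dummy two-dimensional vectors, for illustration only
            return [[float(len(t)), 0.0] for t in texts]

    model = StubTextEmbeddingModel()
    doc_vectors = model.embed_documents(["foo", "a longer text"])
    query_vector = model.embed_documents(["my query"])[0]
    print(doc_vectors, query_vector)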
@@ -13,7 +13,7 @@ import numpy as np
 from langchain.docstore.base import Docstore
 from langchain.docstore.document import Document
 from langchain.docstore.in_memory import InMemoryDocstore
-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.vectorstores.base import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance

@@ -282,7 +282,7 @@ class Annoy(VectorStore):
         cls,
         texts: List[str],
         embeddings: List[List[float]],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         metadatas: Optional[List[dict]] = None,
         metric: str = DEFAULT_METRIC,
         trees: int = 100,
@@ -319,7 +319,7 @@ class Annoy(VectorStore):
     def from_texts(
         cls,
         texts: List[str],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         metadatas: Optional[List[dict]] = None,
         metric: str = DEFAULT_METRIC,
         trees: int = 100,
@@ -351,7 +351,7 @@ class Annoy(VectorStore):
                 embeddings = OpenAIEmbeddings()
                 index = Annoy.from_texts(texts, embeddings)
         """
-        embeddings = embedding.embed_texts(texts)
+        embeddings = embedding.embed_documents(texts)
         return cls.__from(
             texts, embeddings, embedding, metadatas, metric, trees, n_jobs, **kwargs
         )
@@ -360,7 +360,7 @@ class Annoy(VectorStore):
     def from_embeddings(
         cls,
         text_embeddings: List[Tuple[str, List[float]]],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         metadatas: Optional[List[dict]] = None,
         metric: str = DEFAULT_METRIC,
         trees: int = 100,
@@ -424,7 +424,7 @@ class Annoy(VectorStore):
     def load_local(
         cls,
         folder_path: str,
-        embeddings: Embeddings,
+        embeddings: TextEmbeddingModel,
     ) -> Annoy:
         """Load Annoy index, docstore, and index_to_docstore_id to disk.

@@ -8,7 +8,7 @@ from typing import Any, Iterable, List, Optional, Type
 import numpy as np

 from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.vectorstores.base import VectorStore

 logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ class AtlasDB(VectorStore):
     def __init__(
         self,
         name: str,
-        embedding_function: Optional[Embeddings] = None,
+        embedding_function: Optional[TextEmbeddingModel] = None,
         api_key: Optional[str] = None,
         description: str = "A description for your project",
         is_public: bool = True,
@@ -119,7 +119,7 @@ class AtlasDB(VectorStore):

         # Embedding upload case
         if self._embedding_function is not None:
-            _embeddings = self._embedding_function.embed_texts(texts)
+            _embeddings = self._embedding_function.embed_documents(texts)
             embeddings = np.stack(_embeddings)
             if metadatas is None:
                 data = [
@@ -194,7 +194,7 @@ class AtlasDB(VectorStore):
                 "AtlasDB requires an embedding_function for text similarity search!"
             )

-        _embedding = self._embedding_function.embed_texts([query])[0]
+        _embedding = self._embedding_function.embed_documents([query])[0]
         embedding = np.array(_embedding).reshape(1, -1)
         with self.project.wait_for_project_lock():
             neighbors, _ = self.project.projections[0].vector_search(
@@ -212,7 +212,7 @@ class AtlasDB(VectorStore):
     def from_texts(
         cls: Type[AtlasDB],
         texts: List[str],
-        embedding: Optional[Embeddings] = None,
+        embedding: Optional[TextEmbeddingModel] = None,
         metadatas: Optional[List[dict]] = None,
         ids: Optional[List[str]] = None,
         name: Optional[str] = None,
@@ -229,7 +229,7 @@ class AtlasDB(VectorStore):
             texts (List[str]): The list of texts to ingest.
             name (str): Name of the project to create.
             api_key (str): Your nomic API key,
-            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+            embedding (Optional[TextEmbeddingModel]): Embedding function. Defaults to None.
             metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
             ids (Optional[List[str]]): Optional list of document IDs. If None,
                 ids will be auto created
@@ -272,7 +272,7 @@ class AtlasDB(VectorStore):
     def from_documents(
         cls: Type[AtlasDB],
         documents: List[Document],
-        embedding: Optional[Embeddings] = None,
+        embedding: Optional[TextEmbeddingModel] = None,
         ids: Optional[List[str]] = None,
         name: Optional[str] = None,
         api_key: Optional[str] = None,
@@ -289,7 +289,7 @@ class AtlasDB(VectorStore):
             name (str): Name of the collection to create.
             api_key (str): Your nomic API key,
             documents (List[Document]): List of documents to add to the vectorstore.
-            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+            embedding (Optional[TextEmbeddingModel]): Embedding function. Defaults to None.
             ids (Optional[List[str]]): Optional list of document IDs. If None,
                 ids will be auto created
             description (str): A description for your project.
@@ -10,7 +10,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
 from pydantic import BaseModel, Field, root_validator

 from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.schema import BaseRetriever

 VST = TypeVar("VST", bound="VectorStore")
@@ -298,7 +298,7 @@ class VectorStore(ABC):
     def from_documents(
         cls: Type[VST],
         documents: List[Document],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         **kwargs: Any,
     ) -> VST:
         """Return VectorStore initialized from documents and embeddings."""
@@ -310,7 +310,7 @@ class VectorStore(ABC):
     async def afrom_documents(
         cls: Type[VST],
         documents: List[Document],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         **kwargs: Any,
     ) -> VST:
         """Return VectorStore initialized from documents and embeddings."""
@@ -323,7 +323,7 @@ class VectorStore(ABC):
     def from_texts(
         cls: Type[VST],
         texts: List[str],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         metadatas: Optional[List[dict]] = None,
         **kwargs: Any,
     ) -> VST:
@@ -333,7 +333,7 @@ class VectorStore(ABC):
     async def afrom_texts(
         cls: Type[VST],
         texts: List[str],
-        embedding: Embeddings,
+        embedding: TextEmbeddingModel,
         metadatas: Optional[List[dict]] = None,
         **kwargs: Any,
     ) -> VST:
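Because these are the abstract constructors every vectorstore inherits, downstream callers see the same signature change. A hedged sketch of a caller after this diff (Chroma is chosen only because its hunks follow; nothing here is confirmed beyond the signatures shown above):

    # Hypothetical caller after the rename; the TextEmbeddingModel type comes
    # from the import hunks in this diff.
    from typing import List

    from langchain.embeddings.base import TextEmbeddingModel
    from langchain.vectorstores import Chroma

    def build_store(texts: List[str], embedding: TextEmbeddingModel) -> Chroma:
        # Chroma.from_texts internally calls embedding.embed_documents(texts),
        # per the add_texts hunk below.
        return Chroma.from_texts(texts, embedding)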
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Ty
 import numpy as np

 from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.utils import xor_args
 from langchain.vectorstores.base import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance
@@ -58,7 +58,7 @@ class Chroma(VectorStore):
     def __init__(
         self,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
-        embedding_function: Optional[Embeddings] = None,
+        embedding_function: Optional[TextEmbeddingModel] = None,
         persist_directory: Optional[str] = None,
         client_settings: Optional[chromadb.config.Settings] = None,
         collection_metadata: Optional[Dict] = None,
@@ -92,7 +92,7 @@ class Chroma(VectorStore):
         self._persist_directory = persist_directory
         self._collection = self._client.get_or_create_collection(
             name=collection_name,
-            embedding_function=self._embedding_function.embed_texts
+            embedding_function=self._embedding_function.embed_documents
             if self._embedding_function is not None
             else None,
             metadata=collection_metadata,
@@ -156,7 +156,7 @@ class Chroma(VectorStore):
             ids = [str(uuid.uuid1()) for _ in texts]
         embeddings = None
         if self._embedding_function is not None:
-            embeddings = self._embedding_function.embed_texts(list(texts))
+            embeddings = self._embedding_function.embed_documents(list(texts))
         self._collection.add(
             metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
         )
@@ -354,7 +354,7 @@ class Chroma(VectorStore):
     def from_texts(
         cls: Type[Chroma],
         texts: List[str],
-        embedding: Optional[Embeddings] = None,
+        embedding: Optional[TextEmbeddingModel] = None,
         metadatas: Optional[List[dict]] = None,
         ids: Optional[List[str]] = None,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
@@ -372,7 +372,7 @@ class Chroma(VectorStore):
             texts (List[str]): List of texts to add to the collection.
             collection_name (str): Name of the collection to create.
             persist_directory (Optional[str]): Directory to persist the collection.
-            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+            embedding (Optional[TextEmbeddingModel]): Embedding function. Defaults to None.
             metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
             ids (Optional[List[str]]): List of document IDs. Defaults to None.
             client_settings (Optional[chromadb.config.Settings]): Chroma client settings
@@ -394,7 +394,7 @@ class Chroma(VectorStore):
     def from_documents(
         cls: Type[Chroma],
         documents: List[Document],
-        embedding: Optional[Embeddings] = None,
+        embedding: Optional[TextEmbeddingModel] = None,
         ids: Optional[List[str]] = None,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         persist_directory: Optional[str] = None,
@@ -412,7 +412,7 @@ class Chroma(VectorStore):
             persist_directory (Optional[str]): Directory to persist the collection.
             ids (Optional[List[str]]): List of document IDs. Defaults to None.
             documents (List[Document]): List of documents to add to the vectorstore.
-            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+            embedding (Optional[TextEmbeddingModel]): Embedding function. Defaults to None.
             client_settings (Optional[chromadb.config.Settings]): Chroma client settings
         Returns:
             Chroma: Chroma vectorstore.
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tupl
 import numpy as np

 from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
+from langchain.embeddings.base import TextEmbeddingModel
 from langchain.vectorstores.base import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance

@@ -96,7 +96,7 @@ class DeepLake(VectorStore):
         self,
         dataset_path: str = _LANGCHAIN_DEFAULT_DEEPLAKE_PATH,
         token: Optional[str] = None,
-        embedding_function: Optional[Embeddings] = None,
+        embedding_function: Optional[TextEmbeddingModel] = None,
         read_only: Optional[bool] = False,
         ingestion_batch_size: int = 1024,
         num_workers: int = 0,
@@ -224,7 +224,7 @@ class DeepLake(VectorStore):
         embeds: Sequence[Optional[np.ndarray]] = []

         if self._embedding_function is not None:
-            embeddings = self._embedding_function.embed_texts(text_list)
+            embeddings = self._embedding_function.embed_documents(text_list)
             embeds = [np.array(e, dtype=np.float32) for e in embeddings]
         else:
             embeds = [None] * len(text_list)
@@ -494,7 +494,7 @@ class DeepLake(VectorStore):
     def from_texts(
         cls,
         texts: List[str],
-        embedding: Optional[Embeddings] = None,
+        embedding: Optional[TextEmbeddingModel] = None,
         metadatas: Optional[List[dict]] = None,
         ids: Optional[List[str]] = None,
         dataset_path: str = _LANGCHAIN_DEFAULT_DEEPLAKE_PATH,
@@ -522,7 +522,7 @@ class DeepLake(VectorStore):
                 save the dataset, but keeps it in memory instead.
                 Should be used only for testing as it does not persist.
             documents (List[Document]): List of documents to add.
-            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
+            embedding (Optional[TextEmbeddingModel]): Embedding function. Defaults to None.
             metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
             ids (Optional[List[str]]): List of document IDs. Defaults to None.
Some files were not shown because too many files have changed in this diff.